Unverified Commit 728a3f3e authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Rename StableDiffusionOnnxPipeline -> OnnxStableDiffusionPipeline (#887)

Rename and deprecate
parent 100e094c
...@@ -58,7 +58,7 @@ else: ...@@ -58,7 +58,7 @@ else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403 from .utils.dummy_torch_and_transformers_objects import * # noqa F403
if is_torch_available() and is_transformers_available() and is_onnx_available(): if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import StableDiffusionOnnxPipeline from .pipelines import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
else: else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
......
...@@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available(): ...@@ -20,7 +20,7 @@ if is_torch_available() and is_transformers_available():
) )
if is_transformers_available() and is_onnx_available(): if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline from .stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available(): if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline from .stable_diffusion import FlaxStableDiffusionPipeline
...@@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available(): ...@@ -34,7 +34,7 @@ if is_transformers_available() and is_torch_available():
from .safety_checker import StableDiffusionSafetyChecker from .safety_checker import StableDiffusionSafetyChecker
if is_transformers_available() and is_onnx_available(): if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available(): if is_transformers_available() and is_flax_available():
import flax import flax
......
...@@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer ...@@ -8,14 +8,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...onnx_utils import OnnxRuntimeModel from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput from . import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
class StableDiffusionOnnxPipeline(DiffusionPipeline): class OnnxStableDiffusionPipeline(DiffusionPipeline):
vae_decoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer tokenizer: CLIPTokenizer
...@@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline): ...@@ -198,3 +198,27 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline):
def __init__(
self,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
):
deprecation_message = "Please use `OnnxStableDiffusionPipeline` instead of `StableDiffusionOnnxPipeline`."
deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message)
super().__init__(
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
...@@ -4,6 +4,21 @@ ...@@ -4,6 +4,21 @@
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
class OnnxStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "onnx"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers", "onnx"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "onnx"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers", "onnx"])
class StableDiffusionOnnxPipeline(metaclass=DummyObject): class StableDiffusionOnnxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers", "onnx"] _backends = ["torch", "transformers", "onnx"]
......
...@@ -37,13 +37,13 @@ from diffusers import ( ...@@ -37,13 +37,13 @@ from diffusers import (
LDMPipeline, LDMPipeline,
LDMTextToImagePipeline, LDMTextToImagePipeline,
LMSDiscreteScheduler, LMSDiscreteScheduler,
OnnxStableDiffusionPipeline,
PNDMPipeline, PNDMPipeline,
PNDMScheduler, PNDMScheduler,
ScoreSdeVePipeline, ScoreSdeVePipeline,
ScoreSdeVeScheduler, ScoreSdeVeScheduler,
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline, StableDiffusionInpaintPipeline,
StableDiffusionOnnxPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
UNet2DConditionModel, UNet2DConditionModel,
UNet2DModel, UNet2DModel,
...@@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -2010,7 +2010,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_stable_diffusion_onnx(self): def test_stable_diffusion_onnx(self):
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained( sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
) )
...@@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -2214,7 +2214,7 @@ class PipelineTesterMixin(unittest.TestCase):
test_callback_fn.has_been_called = False test_callback_fn.has_been_called = False
pipe = StableDiffusionOnnxPipeline.from_pretrained( pipe = OnnxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider" "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
) )
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment