"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f2d348d9043d9648baedf4bfaeb345aee3529176"
Unverified Commit 86aa747d, authored by Anton Lozhkov and committed by GitHub

Fix ONNX conversion and inference (#1416)

parent d52388f4
@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
         )
         del pipeline.safety_checker
         safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+        feature_extractor = pipeline.feature_extractor
     else:
         safety_checker = None
+        feature_extractor = None
 
     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
         safety_checker=safety_checker,
-        feature_extractor=pipeline.feature_extractor,
+        feature_extractor=feature_extractor,
+        requires_safety_checker=safety_checker is not None,
     )
 
     onnx_pipeline.save_pretrained(output_path)
...
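With this change the conversion script keeps `feature_extractor` and `safety_checker` in sync: either both are exported or both are `None`, and `requires_safety_checker` records that choice in the saved config. A minimal sketch of loading a converted pipeline afterwards (the local path is hypothetical, standing in for whatever `output_path` was used):

```python
# Hypothetical sketch: load a pipeline produced by convert_models and sample from it.
# "./stable_diffusion_onnx" stands in for the output_path passed to the script above.
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./stable_diffusion_onnx", provider="CPUExecutionProvider"
)
# If the checkpoint was converted without a safety checker, both safety_checker
# and feature_extractor are None and loading still succeeds after this fix.
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")
```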
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
 
 import numpy as np
 import torch
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -99,27 +100,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
@@ -130,7 +110,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -213,8 +192,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -228,10 +207,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
-
         if isinstance(prompt, str):
             batch_size = 1
         elif isinstance(prompt, list):
@@ -264,12 +239,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
 
         # get the initial random noise unless the user supplied it
         latents_dtype = text_embeddings.dtype
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            4,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
         elif latents.shape != latents_shape:
...
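The collapsed `latents_shape` line hard-codes the VAE's 8x spatial downsampling instead of deriving it from a `vae` module, which these ONNX pipelines cannot do since they register `vae_encoder` and `vae_decoder` rather than a combined `vae` (hence the removed `self.vae.config` lookup above). A quick illustration of the resulting shape arithmetic at the new 512x512 defaults:

```python
import numpy as np

batch_size = 1
num_images_per_prompt = 2
height = width = 512  # the new fixed defaults

# Same arithmetic as the new latents_shape line: 512 // 8 == 64
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
latents = np.random.default_rng(0).standard_normal(latents_shape).astype(np.float32)
print(latents.shape)  # (2, 4, 64, 64)
```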
@@ -19,7 +19,6 @@ import numpy as np
 
 import torch
 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -135,27 +136,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
...
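Declaring `safety_checker` and `feature_extractor` in `_optional_components` tells `DiffusionPipeline.from_pretrained` that these modules may legitimately be missing or passed as `None`. A hedged sketch of what that enables; the `revision="onnx"` branch of the runwayml repo is an assumption about how the ONNX export is hosted:

```python
# Sketch: optional components can be dropped explicitly at load time.
from diffusers import OnnxStableDiffusionImg2ImgPipeline

pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed to host an ONNX export
    revision="onnx",                   # assumption, see lead-in
    provider="CPUExecutionProvider",
    safety_checker=None,               # allowed: listed in _optional_components
    feature_extractor=None,
)
```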
@@ -19,7 +19,6 @@ import numpy as np
 
 import torch
 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
@@ -149,27 +150,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
@@ -180,7 +160,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
@@ -267,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         prompt: Union[str, List[str]],
         image: PIL.Image.Image,
         mask_image: PIL.Image.Image,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -296,9 +275,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to 512):
                 The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -343,9 +322,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -381,12 +357,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         )
 
         num_channels_latents = NUM_LATENT_CHANNELS
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
...
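The inpaint pipeline gets the same pair of fixes: the removed `self.vae.config` lookup referenced a module the ONNX pipelines never register, and height/width now default to 512 outright, as the updated docstring states. A minimal inpainting sketch under those defaults (the local path and file names are placeholders):

```python
# Hypothetical sketch: inpaint with the fixed 512x512 defaults.
# "./stable_diffusion_inpaint_onnx" stands in for a local conversion output.
import PIL.Image
from diffusers import OnnxStableDiffusionInpaintPipeline

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "./stable_diffusion_inpaint_onnx", provider="CPUExecutionProvider"
)
init_image = PIL.Image.open("scene.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("mask.png").resize((512, 512))  # white = repaint

# height/width default to 512 now, so they can be omitted entirely
result = pipe(prompt="a red park bench", image=init_image, mask_image=mask_image).images[0]
result.save("inpainted.png")
```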