"...text-generation-inference.git" did not exist on "3238c49121b02432bf2938c6ebfd44f06c5adc2f"
Unverified Commit 7b2407f4 authored by Sean Sube's avatar Sean Sube Committed by GitHub
Browse files

add support for pre-calculated prompt embeds to Stable Diffusion ONNX pipelines (#2597)

* add support for prompt embeds to SD ONNX pipeline

* fix up the pipeline copies

* add prompt embeds param to other ONNX pipelines

* fix up prompt embeds param for SD upscaling ONNX pipeline

* add missing type annotations to ONNX pipes
parent 639f6455
...@@ -111,7 +111,15 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -111,7 +111,15 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
) )
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: Optional[int],
do_classifier_free_guidance: bool,
negative_prompt: Optional[str],
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -125,32 +133,48 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -125,32 +133,48 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`). if `guidance_scale` is less than `1`).
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
# get prompt text embeddings elif prompt is not None and isinstance(prompt, list):
text_inputs = self.tokenizer( batch_size = len(prompt)
prompt, else:
padding="max_length", batch_size = prompt_embeds.shape[0]
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) # get prompt text embeddings
logger.warning( text_inputs = self.tokenizer(
"The following part of your input was truncated because CLIP can only handle sequences up to" prompt,
f" {self.tokenizer.model_max_length} tokens: {removed_text}" padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -170,7 +194,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -170,7 +194,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -179,6 +203,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -179,6 +203,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
return_tensors="np", return_tensors="np",
) )
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
if do_classifier_free_guidance:
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
...@@ -188,9 +214,56 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -188,9 +214,56 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
return prompt_embeds return prompt_embeds
def __call__( def check_inputs(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
height: Optional[int],
width: Optional[int],
callback_steps: int,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = 512, height: Optional[int] = 512,
width: Optional[int] = 512, width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50, num_inference_steps: Optional[int] = 50,
...@@ -200,28 +273,86 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -200,28 +273,86 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None, generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None, latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: int = 1, callback_steps: int = 1,
): ):
if isinstance(prompt, str): r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled. *
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
One or a list of [numpy generator(s)](TODO) to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt) batch_size = len(prompt)
else: else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = prompt_embeds.shape[0]
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if generator is None: if generator is None:
generator = np.random generator = np.random
...@@ -232,7 +363,12 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): ...@@ -232,7 +363,12 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
......
...@@ -161,7 +161,15 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -161,7 +161,15 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: Optional[int],
do_classifier_free_guidance: bool,
negative_prompt: Optional[str],
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -175,32 +183,48 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -175,32 +183,48 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`). if `guidance_scale` is less than `1`).
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
# get prompt text embeddings elif prompt is not None and isinstance(prompt, list):
text_inputs = self.tokenizer( batch_size = len(prompt)
prompt, else:
padding="max_length", batch_size = prompt_embeds.shape[0]
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) # get prompt text embeddings
logger.warning( text_inputs = self.tokenizer(
"The following part of your input was truncated because CLIP can only handle sequences up to" prompt,
f" {self.tokenizer.model_max_length} tokens: {removed_text}" padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -220,7 +244,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -220,7 +244,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -229,6 +253,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -229,6 +253,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_tensors="np", return_tensors="np",
) )
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
if do_classifier_free_guidance:
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
...@@ -238,6 +264,48 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -238,6 +264,48 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
return prompt_embeds return prompt_embeds
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    callback_steps: int,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Validate the generation arguments before running the pipeline.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide image generation. Mutually exclusive with `prompt_embeds`.
        callback_steps (`int`):
            Frequency (in steps) at which the callback is invoked; must be a positive integer.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide image generation. Mutually exclusive with
            `negative_prompt_embeds`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings, used instead of `prompt`.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings, used instead of `negative_prompt`.

    Raises:
        ValueError: If `callback_steps` is not a positive integer, if both or neither of
            `prompt`/`prompt_embeds` are provided, if `prompt` has the wrong type, if both
            `negative_prompt` and `negative_prompt_embeds` are provided, or if the two
            embedding arrays have mismatched shapes.
    """
    # `isinstance(None, int)` is False, so a single condition covers None,
    # wrong type, and non-positive values.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one prompt source (text or pre-computed embeddings) must be given.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # When both embedding arrays are supplied directly they are concatenated for
    # classifier-free guidance, so their shapes must match exactly.
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
...@@ -249,6 +317,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -249,6 +317,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None, generator: Optional[np.random.RandomState] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
...@@ -288,6 +358,13 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -288,6 +358,13 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
[`schedulers.DDIMScheduler`], will be ignored for others. [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*): generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic. A np.random.RandomState to make generation deterministic.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -308,24 +385,21 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -308,24 +385,21 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
if isinstance(prompt, str):
# check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt) batch_size = len(prompt)
else: else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = prompt_embeds.shape[0]
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if generator is None: if generator is None:
generator = np.random generator = np.random
...@@ -340,7 +414,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -340,7 +414,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
latents_dtype = prompt_embeds.dtype latents_dtype = prompt_embeds.dtype
......
...@@ -162,7 +162,15 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -162,7 +162,15 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: Optional[int],
do_classifier_free_guidance: bool,
negative_prompt: Optional[str],
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -176,32 +184,48 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -176,32 +184,48 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`). if `guidance_scale` is less than `1`).
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
# get prompt text embeddings elif prompt is not None and isinstance(prompt, list):
text_inputs = self.tokenizer( batch_size = len(prompt)
prompt, else:
padding="max_length", batch_size = prompt_embeds.shape[0]
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) # get prompt text embeddings
logger.warning( text_inputs = self.tokenizer(
"The following part of your input was truncated because CLIP can only handle sequences up to" prompt,
f" {self.tokenizer.model_max_length} tokens: {removed_text}" padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -221,7 +245,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -221,7 +245,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -230,6 +254,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -230,6 +254,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
return_tensors="np", return_tensors="np",
) )
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
if do_classifier_free_guidance:
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
...@@ -239,6 +265,54 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -239,6 +265,54 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
return prompt_embeds return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int],
    width: Optional[int],
    callback_steps: int,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Validate the generation arguments before running the pipeline.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide image generation. Mutually exclusive with `prompt_embeds`.
        height (`int`): Target image height in pixels; must be divisible by 8.
        width (`int`): Target image width in pixels; must be divisible by 8.
        callback_steps (`int`):
            Frequency (in steps) at which the callback is invoked; must be a positive integer.
        negative_prompt (`str`, *optional*):
            The prompt or prompts not to guide image generation. Mutually exclusive with
            `negative_prompt_embeds`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings, used instead of `prompt`.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings, used instead of `negative_prompt`.

    Raises:
        ValueError: If `height`/`width` are not multiples of 8, if `callback_steps` is not a
            positive integer, if both or neither of `prompt`/`prompt_embeds` are provided,
            if `prompt` has the wrong type, if both `negative_prompt` and
            `negative_prompt_embeds` are provided, or if the two embedding arrays have
            mismatched shapes.
    """
    # Latents are spatially downsampled 8x relative to the output image, so both
    # dimensions must be multiples of 8.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `isinstance(None, int)` is False, so a single condition covers None,
    # wrong type, and non-positive values.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one prompt source (text or pre-computed embeddings) must be given.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # When both embedding arrays are supplied directly they are concatenated for
    # classifier-free guidance, so their shapes must match exactly.
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
...@@ -254,6 +328,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -254,6 +328,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
eta: float = 0.0, eta: float = 0.0,
generator: Optional[np.random.RandomState] = None, generator: Optional[np.random.RandomState] = None,
latents: Optional[np.ndarray] = None, latents: Optional[np.ndarray] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
...@@ -300,6 +376,13 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -300,6 +376,13 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`. tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -321,23 +404,18 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -321,23 +404,18 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
if isinstance(prompt, str): # check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt) batch_size = len(prompt)
else: else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = prompt_embeds.shape[0]
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if generator is None: if generator is None:
generator = np.random generator = np.random
...@@ -351,7 +429,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -351,7 +429,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
num_channels_latents = NUM_LATENT_CHANNELS num_channels_latents = NUM_LATENT_CHANNELS
......
...@@ -147,7 +147,15 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -147,7 +147,15 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: Optional[int],
do_classifier_free_guidance: bool,
negative_prompt: Optional[str],
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -161,32 +169,48 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -161,32 +169,48 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
negative_prompt (`str` or `List[str]`): negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`). if `guidance_scale` is less than `1`).
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
""" """
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt is not None and isinstance(prompt, str):
batch_size = 1
# get prompt text embeddings elif prompt is not None and isinstance(prompt, list):
text_inputs = self.tokenizer( batch_size = len(prompt)
prompt, else:
padding="max_length", batch_size = prompt_embeds.shape[0]
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids): if prompt_embeds is None:
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) # get prompt text embeddings
logger.warning( text_inputs = self.tokenizer(
"The following part of your input was truncated because CLIP can only handle sequences up to" prompt,
f" {self.tokenizer.model_max_length} tokens: {removed_text}" padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -206,7 +230,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -206,7 +230,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
else: else:
uncond_tokens = negative_prompt uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1] max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer( uncond_input = self.tokenizer(
uncond_tokens, uncond_tokens,
padding="max_length", padding="max_length",
...@@ -215,6 +239,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -215,6 +239,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return_tensors="np", return_tensors="np",
) )
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
if do_classifier_free_guidance:
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
...@@ -224,6 +250,48 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -224,6 +250,48 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return prompt_embeds return prompt_embeds
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate call arguments for the legacy inpaint pipeline before running it.

    Raises:
        ValueError: if ``callback_steps`` is not a positive integer, if the prompt
            inputs are missing, conflicting, or of the wrong type, or if
            pre-computed embeddings have mismatched shapes.
    """
    # `callback_steps` must be a strictly positive int; None fails the isinstance check.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # Plain negative prompts and pre-computed negative embeddings are mutually exclusive.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Classifier-free guidance concatenates both embedding batches, so shapes must agree.
    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
...@@ -236,6 +304,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -236,6 +304,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0, eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None, generator: Optional[np.random.RandomState] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
...@@ -280,6 +350,13 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -280,6 +350,13 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
[`schedulers.DDIMScheduler`], will be ignored for others. [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*): generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic. A np.random.RandomState to make generation deterministic.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -300,24 +377,21 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -300,24 +377,21 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
if isinstance(prompt, str):
# check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt) batch_size = len(prompt)
else: else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = prompt_embeds.shape[0]
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if generator is None: if generator is None:
generator = np.random generator = np.random
...@@ -333,7 +407,12 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ...@@ -333,7 +407,12 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
latents_dtype = prompt_embeds.dtype latents_dtype = prompt_embeds.dtype
......
...@@ -70,16 +70,85 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): ...@@ -70,16 +70,85 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
eta: float = 0.0, eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
): ):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
noise_level TODO
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs # 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps) self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
...@@ -88,7 +157,13 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): ...@@ -88,7 +157,13 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# 3. Encode input prompt # 3. Encode input prompt
text_embeddings = self._encode_prompt( text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
) )
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
...@@ -199,45 +274,59 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): ...@@ -199,45 +274,59 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
image = image.transpose((0, 2, 3, 1)) image = image.transpose((0, 2, 3, 1))
return image return image
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): def _encode_prompt(
batch_size = len(prompt) if isinstance(prompt, list) else 1 self,
prompt: Union[str, List[str]],
text_inputs = self.tokenizer( device,
prompt, num_images_per_prompt: Optional[int],
padding="max_length", do_classifier_free_guidance: bool,
max_length=self.tokenizer.model_max_length, negative_prompt: Optional[str],
truncation=True, prompt_embeds: Optional[torch.FloatTensor] = None,
return_tensors="pt", negative_prompt_embeds: Optional[torch.FloatTensor] = None,
) ):
text_input_ids = text_inputs.input_ids if prompt is not None and isinstance(prompt, str):
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids batch_size = 1
elif prompt is not None and isinstance(prompt, list):
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): batch_size = len(prompt)
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) else:
logger.warning( batch_size = prompt_embeds.shape[0]
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}" if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
) )
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
# if hasattr(text_inputs, "attention_mask"): # no positional arguments to text_encoder
# attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder(
# else: input_ids=text_input_ids.int().to(device),
# attention_mask = None # attention_mask=attention_mask,
)
# no positional arguments to text_encoder prompt_embeds = prompt_embeds[0]
text_embeddings = self.text_encoder(
input_ids=text_input_ids.int().to(device),
# attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
bs_embed, seq_len, _ = text_embeddings.shape bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt) prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance # get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance: if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str] uncond_tokens: List[str]
if negative_prompt is None: if negative_prompt is None:
uncond_tokens = [""] * batch_size uncond_tokens = [""] * batch_size
...@@ -277,6 +366,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): ...@@ -277,6 +366,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
) )
uncond_embeddings = uncond_embeddings[0] uncond_embeddings = uncond_embeddings[0]
if do_classifier_free_guidance:
seq_len = uncond_embeddings.shape[1] seq_len = uncond_embeddings.shape[1]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
...@@ -285,6 +375,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): ...@@ -285,6 +375,6 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
# For classifier free guidance, we need to do two forward passes. # For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch # Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes # to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds])
return text_embeddings return prompt_embeds
...@@ -133,6 +133,76 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes ...@@ -133,6 +133,76 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.Tes
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_prompt_embeds(self):
    """Feeding pre-computed `prompt_embeds` must reproduce the plain-`prompt` output."""
    sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
    sd_pipe.set_progress_bar_config(disable=None)

    # Baseline run: a batch of three identical text prompts.
    call_kwargs = self.get_dummy_inputs()
    call_kwargs["prompt"] = 3 * [call_kwargs["prompt"]]
    baseline_slice = sd_pipe(**call_kwargs).images[0, -3:, -3:, -1]

    # Second run: encode the same prompts up front and pass them as embeddings.
    call_kwargs = self.get_dummy_inputs()
    batched_prompt = 3 * [call_kwargs.pop("prompt")]
    token_ids = sd_pipe.tokenizer(
        batched_prompt,
        padding="max_length",
        max_length=sd_pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="np",
    )["input_ids"]
    call_kwargs["prompt_embeds"] = sd_pipe.text_encoder(input_ids=token_ids.astype(np.int32))[0]
    embeds_slice = sd_pipe(**call_kwargs).images[0, -3:, -3:, -1]

    # Both paths must agree to within numerical tolerance.
    assert np.abs(baseline_slice.flatten() - embeds_slice.flatten()).max() < 1e-4
def test_stable_diffusion_negative_prompt_embeds(self):
    """Pre-computed negative embeddings must match the plain `negative_prompt` path."""
    sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
    sd_pipe.set_progress_bar_config(disable=None)

    negative_prompt = 3 * ["this is a negative prompt"]

    # Baseline run: text prompts plus text negative prompts.
    call_kwargs = self.get_dummy_inputs()
    call_kwargs["negative_prompt"] = negative_prompt
    call_kwargs["prompt"] = 3 * [call_kwargs["prompt"]]
    baseline_slice = sd_pipe(**call_kwargs).images[0, -3:, -3:, -1]

    # Second run: encode both the positive and the negative prompts up front.
    call_kwargs = self.get_dummy_inputs()
    prompt = 3 * [call_kwargs.pop("prompt")]
    encoded = []
    for text_batch in (prompt, negative_prompt):
        token_ids = sd_pipe.tokenizer(
            text_batch,
            padding="max_length",
            max_length=sd_pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )["input_ids"]
        encoded.append(sd_pipe.text_encoder(input_ids=token_ids.astype(np.int32))[0])
    call_kwargs["prompt_embeds"], call_kwargs["negative_prompt_embeds"] = encoded
    embeds_slice = sd_pipe(**call_kwargs).images[0, -3:, -3:, -1]

    # Both paths must agree to within numerical tolerance.
    assert np.abs(baseline_slice.flatten() - embeds_slice.flatten()).max() < 1e-4
@nightly @nightly
@require_onnxruntime @require_onnxruntime
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment