Unverified Commit d43972ae authored by Jorge C. Gomes's avatar Jorge C. Gomes Committed by GitHub
Browse files

Fixes prompt input checks in StableDiffusion img2img pipeline (#2206)

* Fixes prompt input checks in img2img

Allows providing prompt_embeds instead of the prompt, which is not currently possible as the first check fails.
This becomes the same as the function found in https://github.com/huggingface/diffusers/blob/8267c7844504b55366525169187767ef92d1f499/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L393

* Continues the fix

This also needs to be fixed. Becomes consistent with https://github.com/huggingface/diffusers/blob/8267c7844504b55366525169187767ef92d1f499/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L558

I've now tested this implementation, and it produces the expected results.
parent ffed2420
...@@ -428,9 +428,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -428,9 +428,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
def check_inputs( def check_inputs(
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
): ):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
...@@ -623,7 +620,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -623,7 +620,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment