Unverified Commit d041dd50 authored by SahilCarterr's avatar SahilCarterr Committed by GitHub
Browse files

Added Error when len(gligen_images ) is not equal to len(gligen_phrases) in...

Added Error when len(gligen_images ) is not equal to len(gligen_phrases) in StableDiffusionGLIGENTextImagePipeline (#10176)

* added check value error

* fix style
parent 09675934
...@@ -446,13 +446,14 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM ...@@ -446,13 +446,14 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
def check_inputs( def check_inputs(
self, self,
prompt, prompt,
height, height,
width, width,
callback_steps, callback_steps,
gligen_images,
gligen_phrases,
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
...@@ -499,6 +500,13 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM ...@@ -499,6 +500,13 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if gligen_images is not None and gligen_phrases is not None:
if len(gligen_images) != len(gligen_phrases):
raise ValueError(
"`gligen_images` and `gligen_phrases` must have the same length when both are provided, but"
f" got: `gligen_images` with length {len(gligen_images)} != `gligen_phrases` with length {len(gligen_phrases)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = ( shape = (
...@@ -814,6 +822,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM ...@@ -814,6 +822,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
height, height,
width, width,
callback_steps, callback_steps,
gligen_images,
gligen_phrases,
negative_prompt, negative_prompt,
prompt_embeds, prompt_embeds,
negative_prompt_embeds, negative_prompt_embeds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment