Unverified Commit f1f542bd authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Update pipeline_stable_diffusion_3_controlnet.py (#8660)


Co-authored-by: default avatarYiYi Xu <yixu310@gmail,com>
parent a9c403c0
...@@ -513,6 +513,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -513,6 +513,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
pooled_prompt_embeds=None, pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
): ):
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
...@@ -584,6 +585,9 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -584,6 +585,9 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
) )
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
def prepare_latents( def prepare_latents(
self, self,
...@@ -710,6 +714,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -710,6 +714,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
clip_skip: Optional[int] = None, clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -811,6 +816,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -811,6 +816,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
Examples: Examples:
...@@ -850,6 +856,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -850,6 +856,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
pooled_prompt_embeds=pooled_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
) )
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
...@@ -888,6 +895,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, ...@@ -888,6 +895,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
device=device, device=device,
clip_skip=self.clip_skip, clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
) )
if self.do_classifier_free_guidance: if self.do_classifier_free_guidance:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment