Unverified Commit 93044f89 authored by machineminded's avatar machineminded Committed by GitHub
Browse files

add capability to interrupt the pipeline (#119)



* add capability to interrupt the pipeline

* Update pipeline.py

---------
Co-authored-by: default avatarZhen Li <zhenli1031@gmail.com>
parent 244f4761
...@@ -216,6 +216,9 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): ...@@ -216,6 +216,9 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
return prompt_embeds, pooled_prompt_embeds, class_tokens_mask return prompt_embeds, pooled_prompt_embeds, class_tokens_mask
@property
def interrupt(self):
return self._interrupt
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
...@@ -246,6 +249,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): ...@@ -246,6 +249,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
target_size: Optional[Tuple[int, int]] = None, target_size: Optional[Tuple[int, int]] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1, callback_steps: int = 1,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
# Added parameters (for PhotoMaker) # Added parameters (for PhotoMaker)
input_id_images: PipelineImageInput = None, input_id_images: PipelineImageInput = None,
start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future
...@@ -295,7 +300,11 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): ...@@ -295,7 +300,11 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
negative_prompt_embeds, negative_prompt_embeds,
pooled_prompt_embeds, pooled_prompt_embeds,
negative_pooled_prompt_embeds, negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs,
) )
self._interrupt = False
# #
if prompt_embeds is not None and class_tokens_mask is None: if prompt_embeds is not None and class_tokens_mask is None:
raise ValueError( raise ValueError(
...@@ -426,6 +435,9 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): ...@@ -426,6 +435,9 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = ( latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents torch.cat([latents] * 2) if do_classifier_free_guidance else latents
) )
...@@ -464,11 +476,28 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): ...@@ -464,11 +476,28 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
# negative_pooled_prompt_embeds = callback_outputs.pop(
# "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
# )
# add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
# negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update() progress_bar.update()
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, latents) step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# make sure the VAE is in float32 mode, as it overflows in float16 # make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
...@@ -487,9 +516,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline): ...@@ -487,9 +516,8 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
image = self.image_processor.postprocess(image, output_type=output_type) image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU # Offload all models
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.maybe_free_model_hooks()
self.final_offload_hook.offload()
if not return_dict: if not return_dict:
return (image,) return (image,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment