Unverified Commit bdbcaa98 authored by rafael's avatar rafael Committed by GitHub
Browse files

lpw_stable_diffusion: Add is_cancelled_callback (#1053)

* [Community Pipelines] lpw_stable_diffusion: Add is_cancelled_callback

* [Community pipelines] lpw_stable_diffusion_onnx: Add is_cancelled_callback
parent 8ee21915
...@@ -498,6 +498,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): ...@@ -498,6 +498,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs, **kwargs,
): ):
...@@ -560,11 +561,15 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): ...@@ -560,11 +561,15 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
callback (`Callable`, *optional*): callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1): callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
Returns: Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a When returning a tuple, the first element is a list with the generated images, and the second element is a
...@@ -757,8 +762,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): ...@@ -757,8 +762,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask)) latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided # call the callback, if provided
if callback is not None and i % callback_steps == 0: if i % callback_steps == 0:
callback(i, t, latents) if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample image = self.vae.decode(latents).sample
......
...@@ -435,6 +435,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): ...@@ -435,6 +435,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
**kwargs, **kwargs,
): ):
...@@ -496,11 +497,15 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): ...@@ -496,11 +497,15 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
callback (`Callable`, *optional*): callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
is_cancelled_callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. If the function returns
`True`, the inference will be cancelled.
callback_steps (`int`, *optional*, defaults to 1): callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step. called at every step.
Returns: Returns:
`None` if cancelled by `is_cancelled_callback`,
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a When returning a tuple, the first element is a list with the generated images, and the second element is a
...@@ -668,8 +673,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): ...@@ -668,8 +673,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
latents = (init_latents_proper * mask) + (latents * (1 - mask)) latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided # call the callback, if provided
if callback is not None and i % callback_steps == 0: if i % callback_steps == 0:
callback(i, t, latents) if callback is not None:
callback(i, t, latents)
if is_cancelled_callback is not None and is_cancelled_callback():
return None
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0] # image = self.vae_decoder(latent_sample=latents)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment