Unverified Commit 6281d206 authored by Carson Katri's avatar Carson Katri Committed by GitHub
Browse files

Add callbacks to `WuerstchenDecoderPipeline` and `WuerstchenCombinedPipeline` (#5154)

parent 28254c79
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Union from typing import Callable, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -202,6 +202,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -202,6 +202,8 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -240,6 +242,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -240,6 +242,12 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples: Examples:
...@@ -315,7 +323,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -315,7 +323,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler) latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop # 6. Run denoising loop
for t in self.progress_bar(timesteps[:-1]): for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype) ratio = t.expand(latents.size(0)).to(dtype)
effnet = ( effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
...@@ -343,6 +351,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -343,6 +351,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
generator=generator, generator=generator,
).prev_sample ).prev_sample
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Scale and decode the image latents with vq-vae # 10. Scale and decode the image latents with vq-vae
latents = self.vqgan.config.scale_factor * latents latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1) images = self.vqgan.decode(latents).sample.clamp(0, 1)
......
...@@ -161,6 +161,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -161,6 +161,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
prior_callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
prior_callback_steps: int = 1,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -222,6 +226,18 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -222,6 +226,18 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
(`np.array`) or `"pt"` (`torch.Tensor`). (`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback (`Callable`, *optional*):
A function that will be called every `prior_callback_steps` steps during inference. The function will be
called with the following arguments: `prior_callback(step: int, timestep: int, latents: torch.FloatTensor)`.
prior_callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples: Examples:
...@@ -244,6 +260,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -244,6 +260,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
latents=latents, latents=latents,
output_type="pt", output_type="pt",
return_dict=False, return_dict=False,
callback=prior_callback,
callback_steps=prior_callback_steps,
) )
image_embeddings = prior_outputs[0] image_embeddings = prior_outputs[0]
...@@ -257,6 +275,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline): ...@@ -257,6 +275,8 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
generator=generator, generator=generator,
output_type=output_type, output_type=output_type,
return_dict=return_dict, return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
) )
return outputs return outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment