Unverified Commit fa1bdce3 authored by Aryan, committed by GitHub

[docs] Improve SVD pipeline docs (#7087)



* update svd docs

* fix example doc string

* update return type hints/docs

* update type hints

* Fix typos in pipeline_stable_video_diffusion.py

* make style && make fix-copies

* Update src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update based on suggestion

---------
Co-authored-by: M. Tolga Cangöz <mtcangoz@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent ca6cdc77
@@ -21,16 +21,33 @@ import PIL.Image
 import torch
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
 from ...schedulers import EulerDiscreteScheduler
-from ...utils import BaseOutput, logging
+from ...utils import BaseOutput, logging, replace_example_docstring
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableVideoDiffusionPipeline
+        >>> from diffusers.utils import load_image, export_to_video
+        >>> pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
+        >>> pipe.to("cuda")
+        >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg")
+        >>> image = image.resize((1024, 576))
+        >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
+        >>> export_to_video(frames, "generated.mp4", fps=7)
+        ```
+"""
 def _append_dims(x, target_dims):
     """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
@@ -41,7 +58,7 @@ def _append_dims(x, target_dims):
 # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
-def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -65,15 +82,15 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
 @dataclass
 class StableVideoDiffusionPipelineOutput(BaseOutput):
     r"""
-    Output class for zero-shot text-to-video pipeline.
+    Output class for Stable Video Diffusion pipeline.
     Args:
-        frames (`[List[PIL.Image.Image]`, `np.ndarray`]):
-            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
-            num_channels)`.
+        frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
+            List of denoised PIL images of length `batch_size` or numpy array or torch tensor
+            of shape `(batch_size, num_frames, height, width, num_channels)`.
     """
-    frames: Union[List[PIL.Image.Image], np.ndarray]
+    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.FloatTensor]
 class StableVideoDiffusionPipeline(DiffusionPipeline):
@@ -119,7 +136,13 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-    def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+    def _encode_image(
+        self,
+        image: PipelineImageInput,
+        device: Union[str, torch.device],
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
+    ) -> torch.FloatTensor:
         dtype = next(self.image_encoder.parameters()).dtype
         if not isinstance(image, torch.Tensor):
@@ -164,9 +187,9 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
     def _encode_vae_image(
         self,
         image: torch.Tensor,
-        device,
-        num_videos_per_prompt,
-        do_classifier_free_guidance,
+        device: Union[str, torch.device],
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
     ):
         image = image.to(device=device)
         image_latents = self.vae.encode(image).latent_dist.mode()
@@ -186,13 +209,13 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
     def _get_add_time_ids(
         self,
-        fps,
-        motion_bucket_id,
-        noise_aug_strength,
-        dtype,
-        batch_size,
-        num_videos_per_prompt,
-        do_classifier_free_guidance,
+        fps: int,
+        motion_bucket_id: int,
+        noise_aug_strength: float,
+        dtype: torch.dtype,
+        batch_size: int,
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
     ):
         add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
@@ -212,7 +235,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         return add_time_ids
-    def decode_latents(self, latents, num_frames, decode_chunk_size=14):
+    def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
         # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
         latents = latents.flatten(0, 1)
@@ -257,15 +280,15 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
     def prepare_latents(
         self,
-        batch_size,
-        num_frames,
-        num_channels_latents,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
+        batch_size: int,
+        num_frames: int,
+        num_channels_latents: int,
+        height: int,
+        width: int,
+        dtype: torch.dtype,
+        device: Union[str, torch.device],
+        generator: torch.Generator,
+        latents: Optional[torch.FloatTensor] = None,
     ):
         shape = (
             batch_size,
@@ -307,6 +330,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         return self._num_timesteps
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
@@ -333,15 +357,16 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         Args:
             image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
-                Image or images to guide image generation. If you provide a tensor, the expected value range is between `[0,1]`.
+                Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0, 1]`.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
             num_frames (`int`, *optional*):
-                The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
+                The number of video frames to generate. Defaults to `self.unet.config.num_frames`
+                (14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
             num_inference_steps (`int`, *optional*, defaults to 25):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                 expense of slower inference. This parameter is modulated by `strength`.
             min_guidance_scale (`float`, *optional*, defaults to 1.0):
                 The minimum guidance scale. Used for the classifier free guidance with first frame.
@@ -351,29 +376,29 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 Frames per second. The rate at which the generated images shall be exported to a video after generation.
                 Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
             motion_bucket_id (`int`, *optional*, defaults to 127):
-                The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
+                Used for conditioning the amount of motion for the generation. The higher the number the more motion
+                will be in the video.
             noise_aug_strength (`float`, *optional*, defaults to 0.02):
                 The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
             decode_chunk_size (`int`, *optional*):
-                The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
-                between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
-                for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
+                The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the expense of more memory usage. By default, the decoder decodes all frames at once for maximal
+                quality. For lower memory usage, reduce `decode_chunk_size`.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
+                The number of videos to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
             latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
             output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+                The output format of the generated image. Choose between `pil`, `np` or `pt`.
             callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments:
+                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
+                `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
             callback_on_step_end_tensor_inputs (`List`, *optional*):
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -382,26 +407,12 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
+        Examples:
         Returns:
             [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
-                otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
+                otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`) is returned.
-        Examples:
-        ```py
-        from diffusers import StableVideoDiffusionPipeline
-        from diffusers.utils import load_image, export_to_video
-        pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
-        pipe.to("cuda")
-        image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
-        image = image.resize((1024, 576))
-        frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
-        export_to_video(frames, "generated.mp4", fps=7)
-        ```
         """
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -429,8 +440,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         # 3. Encode input image
         image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
-        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
-        # is why it is reduced here.
+        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
         # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
         fps = fps - 1
@@ -471,11 +481,11 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         )
         added_time_ids = added_time_ids.to(device)
-        # 4. Prepare timesteps
+        # 6. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
-        # 5. Prepare latent variables
+        # 7. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
@@ -489,7 +499,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
             latents,
         )
-        # 7. Prepare guidance scale
+        # 8. Prepare guidance scale
         guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
         guidance_scale = guidance_scale.to(device, latents.dtype)
         guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
@@ -497,7 +507,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
         self._guidance_scale = guidance_scale
-        # 8. Denoising loop
+        # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -506,7 +516,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-                # Concatenate image_latents over channels dimention
+                # Concatenate image_latents over channels dimension
                 latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
                 # predict the noise residual
......
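For reference, here is a minimal end-to-end sketch assembled from the parameters documented in this diff. It assumes the `stabilityai/stable-video-diffusion-img2vid-xt` checkpoint and a CUDA device are available; the `log_step` callback is purely illustrative (it only prints the latents shape via `callback_on_step_end`), and the guidance/motion values below are example settings rather than prescribed ones.

```py
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.to("cuda")

# Conditioning image; the -xt checkpoint expects 1024x576 input.
image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
)
image = image.resize((1024, 576))


def log_step(pipeline, step, timestep, callback_kwargs):
    # `callback_kwargs` holds the tensors requested via `callback_on_step_end_tensor_inputs`.
    print(f"step {step}: latents shape {tuple(callback_kwargs['latents'].shape)}")
    return callback_kwargs


frames = pipe(
    image,
    num_frames=25,  # 25 for -img2vid-xt, 14 for the base -img2vid checkpoint
    decode_chunk_size=8,  # decode fewer frames at a time to lower peak memory
    motion_bucket_id=127,  # higher values produce more motion
    noise_aug_strength=0.02,  # more noise on the conditioning image also means more motion
    min_guidance_scale=1.0,
    max_guidance_scale=3.0,
    generator=torch.manual_seed(42),
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).frames[0]

export_to_video(frames, "generated.mp4", fps=7)
```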