Unverified Commit 04f4bd54 authored by Sayak Paul, committed by GitHub

[Core] introduce videoprocessor. (#7776)



* introduce videoprocessor.

* fix quality

* address yiyi's feedback

* fix preprocess_video call.

* video_processor -> image_processor

* fix

* fix more.

* quality

* image_processor -> video_processor

* support List[List[PIL.Image.Image]]

* change to video_processor.

* documentation

* Apply suggestions from code review

* changes

* remove print.

* refactor video processor (part # 7776) (#7861)

* update

* update remove deprecate

* Update src/diffusers/video_processor.py

* update

* Apply suggestions from code review

* deprecate list of 5d for video and list of 4d for image + apply other feedbacks

* up

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* add doc.

* tensor2vid -> postprocess_video.

* refactor preprocess with preprocess_video

* set default values.

* empty commit

* more refactoring of prepare_latents in animatediff vid2vid

* checking documentation

* remove documentation for now.

* fix animatediff sdxl

* fix test failure [part of video processor PR] (#7905)

up

* remove preceed_with_frames.

* doc

* fix

* fix

* remove video input as a single-frame video.

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 82be58c5
...@@ -439,6 +439,8 @@ ...@@ -439,6 +439,8 @@
title: Utilities title: Utilities
- local: api/image_processor - local: api/image_processor
title: VAE Image Processor title: VAE Image Processor
- local: api/video_processor
title: Video Processor
title: Internal classes title: Internal classes
isExpanded: false isExpanded: false
title: API title: API
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Video Processor
The `VideoProcessor` provides a unified API for video pipelines to prepare video inputs for VAE encoding and to post-process the outputs once they're decoded. The class inherits from [`VaeImageProcessor`], so it includes transformations such as resizing, normalization, and conversion between PIL images, NumPy arrays, and PyTorch tensors.
\ No newline at end of file
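For orientation, a minimal usage sketch of the new class (illustrative only, not part of the committed documentation; the dummy clip shape and constructor arguments mirror the tests added later in this commit):

import torch
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, do_normalize=True)

# Dummy clip with shape (batch_size, num_frames, num_channels, height, width), values in [0, 1].
clip = torch.rand(1, 5, 3, 8, 8)

# preprocess_video returns a (batch_size, num_channels, num_frames, height, width) tensor in [-1, 1],
# ready for VAE encoding; postprocess_video converts a decoded tensor back to "np", "pt", or "pil" frames.
video_tensor = video_processor.preprocess_video(clip)
frames = video_processor.postprocess_video(video_tensor, output_type="pil")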
...@@ -29,15 +29,34 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate ...@@ -29,15 +29,34 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
PipelineImageInput = Union[ PipelineImageInput = Union[
PIL.Image.Image, PIL.Image.Image,
np.ndarray, np.ndarray,
torch.FloatTensor, torch.Tensor,
List[PIL.Image.Image], List[PIL.Image.Image],
List[np.ndarray], List[np.ndarray],
List[torch.FloatTensor], List[torch.Tensor],
] ]
PipelineDepthInput = PipelineImageInput PipelineDepthInput = PipelineImageInput
def is_valid_image(image):
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
def is_valid_image_imagelist(images):
# Check if the image input is one of the supported formats for an image or a list of images.
# It can be one of the following three:
# (1) a 4-d pytorch tensor or numpy array,
# (2) a valid image: a PIL.Image.Image, a 2-d np.ndarray or torch.Tensor (grayscale image), or a 3-d np.ndarray or torch.Tensor,
# (3) a list of valid images.
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
return True
elif is_valid_image(images):
return True
elif isinstance(images, list):
return all(is_valid_image(image) for image in images)
return False
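For a quick sanity check of what these helpers accept, a minimal sketch (illustrative only, not part of the diff; it only assumes the two functions above are importable from `diffusers.image_processor`):

import numpy as np
import PIL.Image
import torch
from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

frame = PIL.Image.new("RGB", (64, 64))

assert is_valid_image(frame)                                        # a single PIL image
assert is_valid_image(np.zeros((64, 64, 3)))                        # a 3-d array (height, width, channels)
assert is_valid_image_imagelist([frame, frame])                     # a list of valid images
assert is_valid_image_imagelist(torch.zeros(2, 3, 64, 64))          # a 4-d batched tensor
assert not is_valid_image_imagelist(torch.zeros(1, 2, 3, 64, 64))   # 5-d inputs are rejected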
class VaeImageProcessor(ConfigMixin): class VaeImageProcessor(ConfigMixin):
""" """
Image processor for VAE. Image processor for VAE.
...@@ -110,7 +129,7 @@ class VaeImageProcessor(ConfigMixin): ...@@ -110,7 +129,7 @@ class VaeImageProcessor(ConfigMixin):
return images return images
@staticmethod @staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
""" """
Convert a NumPy image to a PyTorch tensor. Convert a NumPy image to a PyTorch tensor.
""" """
...@@ -121,7 +140,7 @@ class VaeImageProcessor(ConfigMixin): ...@@ -121,7 +140,7 @@ class VaeImageProcessor(ConfigMixin):
return images return images
@staticmethod @staticmethod
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
""" """
Convert a PyTorch tensor to a NumPy image. Convert a PyTorch tensor to a NumPy image.
""" """
...@@ -497,12 +516,27 @@ class VaeImageProcessor(ConfigMixin): ...@@ -497,12 +516,27 @@ class VaeImageProcessor(ConfigMixin):
else: else:
image = np.expand_dims(image, axis=-1) image = np.expand_dims(image, axis=-1)
if isinstance(image, supported_formats): if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
image = [image] warnings.warn(
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): "Passing `image` as a list of 4d np.ndarray is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
FutureWarning,
)
image = np.concatenate(image, axis=0)
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d torch.Tensor is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
FutureWarning,
)
image = torch.cat(image, axis=0)
if not is_valid_image_imagelist(image):
raise ValueError( raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}"
) )
if not isinstance(image, list):
image = [image]
if isinstance(image[0], PIL.Image.Image): if isinstance(image[0], PIL.Image.Image):
if crops_coords is not None: if crops_coords is not None:
...@@ -561,15 +595,15 @@ class VaeImageProcessor(ConfigMixin): ...@@ -561,15 +595,15 @@ class VaeImageProcessor(ConfigMixin):
def postprocess( def postprocess(
self, self,
image: torch.FloatTensor, image: torch.Tensor,
output_type: str = "pil", output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None, do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
""" """
Postprocess the image output from tensor to `output_type`. Postprocess the image output from tensor to `output_type`.
Args: Args:
image (`torch.FloatTensor`): image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`. The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`): output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
...@@ -578,7 +612,7 @@ class VaeImageProcessor(ConfigMixin): ...@@ -578,7 +612,7 @@ class VaeImageProcessor(ConfigMixin):
`VaeImageProcessor` config. `VaeImageProcessor` config.
Returns: Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image. The postprocessed image.
""" """
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
...@@ -738,15 +772,15 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): ...@@ -738,15 +772,15 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def postprocess( def postprocess(
self, self,
image: torch.FloatTensor, image: torch.Tensor,
output_type: str = "pil", output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None, do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
""" """
Postprocess the image output from tensor to `output_type`. Postprocess the image output from tensor to `output_type`.
Args: Args:
image (`torch.FloatTensor`): image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`. The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`): output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
...@@ -755,7 +789,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): ...@@ -755,7 +789,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
`VaeImageProcessor` config. `VaeImageProcessor` config.
Returns: Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image. The postprocessed image.
""" """
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
...@@ -793,8 +827,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor): ...@@ -793,8 +827,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def preprocess( def preprocess(
self, self,
rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
target_res: Optional[int] = None, target_res: Optional[int] = None,
...@@ -933,13 +967,13 @@ class IPAdapterMaskProcessor(VaeImageProcessor): ...@@ -933,13 +967,13 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
) )
@staticmethod @staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int): def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
""" """
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
Args: Args:
mask (`torch.FloatTensor`): mask (`torch.Tensor`):
The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
batch_size (`int`): batch_size (`int`):
The batch size. The batch size.
...@@ -949,7 +983,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor): ...@@ -949,7 +983,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
The dimensionality of the value embeddings. The dimensionality of the value embeddings.
Returns: Returns:
`torch.FloatTensor`: `torch.Tensor`:
The downsampled mask tensor. The downsampled mask tensor.
""" """
......
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
...@@ -41,6 +40,7 @@ from ...utils import ( ...@@ -41,6 +40,7 @@ from ...utils import (
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
...@@ -65,27 +65,6 @@ EXAMPLE_DOC_STRING = """ ...@@ -65,27 +65,6 @@ EXAMPLE_DOC_STRING = """
""" """
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
class AnimateDiffPipeline( class AnimateDiffPipeline(
DiffusionPipeline, DiffusionPipeline,
StableDiffusionMixin, StableDiffusionMixin,
...@@ -159,7 +138,7 @@ class AnimateDiffPipeline( ...@@ -159,7 +138,7 @@ class AnimateDiffPipeline(
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt( def encode_prompt(
...@@ -836,7 +815,7 @@ class AnimateDiffPipeline( ...@@ -836,7 +815,7 @@ class AnimateDiffPipeline(
video = latents video = latents
else: else:
video_tensor = self.decode_latents(latents) video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models # 10. Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch import torch
from transformers import ( from transformers import (
CLIPImageProcessor, CLIPImageProcessor,
...@@ -25,7 +24,7 @@ from transformers import ( ...@@ -25,7 +24,7 @@ from transformers import (
CLIPVisionModelWithProjection, CLIPVisionModelWithProjection,
) )
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput
from ...loaders import ( from ...loaders import (
FromSingleFileMixin, FromSingleFileMixin,
IPAdapterMixin, IPAdapterMixin,
...@@ -57,6 +56,7 @@ from ...utils import ( ...@@ -57,6 +56,7 @@ from ...utils import (
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
...@@ -113,28 +113,6 @@ EXAMPLE_DOC_STRING = """ ...@@ -113,28 +113,6 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
""" """
...@@ -320,7 +298,7 @@ class AnimateDiffSDXLPipeline( ...@@ -320,7 +298,7 @@ class AnimateDiffSDXLPipeline(
) )
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = self.unet.config.sample_size
...@@ -1291,7 +1269,7 @@ class AnimateDiffSDXLPipeline( ...@@ -1291,7 +1269,7 @@ class AnimateDiffSDXLPipeline(
video = latents video = latents
else: else:
video_tensor = self.decode_latents(latents) video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# cast back to fp16 if needed # cast back to fp16 if needed
if needs_upcasting: if needs_upcasting:
......
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
...@@ -34,6 +33,7 @@ from ...schedulers import ( ...@@ -34,6 +33,7 @@ from ...schedulers import (
) )
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput from .pipeline_output import AnimateDiffPipelineOutput
...@@ -95,28 +95,6 @@ EXAMPLE_DOC_STRING = """ ...@@ -95,28 +95,6 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents( def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
...@@ -264,7 +242,7 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -264,7 +242,7 @@ class AnimateDiffVideoToVideoPipeline(
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt( def encode_prompt(
...@@ -650,16 +628,7 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -650,16 +628,7 @@ class AnimateDiffVideoToVideoPipeline(
generator, generator,
latents=None, latents=None,
): ):
# video must be a list of list of images
# the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
# as a list of images
if video and not isinstance(video[0], list):
video = [video]
if latents is None: if latents is None:
video = torch.cat(
[self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
)
video = video.to(device=device, dtype=dtype)
num_frames = video.shape[1] num_frames = video.shape[1]
else: else:
num_frames = latents.shape[2] num_frames = latents.shape[2]
...@@ -943,6 +912,11 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -943,6 +912,11 @@ class AnimateDiffVideoToVideoPipeline(
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
# 5. Prepare latent variables # 5. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=prompt_embeds.dtype)
num_channels_latents = self.unet.config.in_channels num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
video=video, video=video,
...@@ -1023,7 +997,7 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -1023,7 +997,7 @@ class AnimateDiffVideoToVideoPipeline(
video = latents video = latents
else: else:
video_tensor = self.decode_latents(latents) video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models # 10. Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
......
...@@ -31,6 +31,7 @@ from ...utils import ( ...@@ -31,6 +31,7 @@ from ...utils import (
replace_example_docstring, replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...@@ -70,28 +71,6 @@ EXAMPLE_DOC_STRING = """ ...@@ -70,28 +71,6 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
@dataclass @dataclass
class I2VGenXLPipelineOutput(BaseOutput): class I2VGenXLPipelineOutput(BaseOutput):
r""" r"""
...@@ -156,7 +135,7 @@ class I2VGenXLPipeline( ...@@ -156,7 +135,7 @@ class I2VGenXLPipeline(
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# `do_resize=False` as we do custom resizing. # `do_resize=False` as we do custom resizing.
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)
@property @property
def guidance_scale(self): def guidance_scale(self):
...@@ -342,8 +321,8 @@ class I2VGenXLPipeline( ...@@ -342,8 +321,8 @@ class I2VGenXLPipeline(
dtype = next(self.image_encoder.parameters()).dtype dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
image = self.image_processor.pil_to_numpy(image) image = self.video_processor.pil_to_numpy(image)
image = self.image_processor.numpy_to_pt(image) image = self.video_processor.numpy_to_pt(image)
# Normalize the image with CLIP training stats. # Normalize the image with CLIP training stats.
image = self.feature_extractor( image = self.feature_extractor(
...@@ -657,7 +636,7 @@ class I2VGenXLPipeline( ...@@ -657,7 +636,7 @@ class I2VGenXLPipeline(
# 3.2.2 Image latents. # 3.2.2 Image latents.
resized_image = _center_crop_wide(image, (width, height)) resized_image = _center_crop_wide(image, (width, height))
image = self.image_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype)
image_latents = self.prepare_image_latents( image_latents = self.prepare_image_latents(
image, image,
device=device, device=device,
...@@ -737,7 +716,7 @@ class I2VGenXLPipeline( ...@@ -737,7 +716,7 @@ class I2VGenXLPipeline(
video = latents video = latents
else: else:
video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 9. Offload all models # 9. Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
......
...@@ -21,7 +21,7 @@ import PIL ...@@ -21,7 +21,7 @@ import PIL
import torch import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
...@@ -43,6 +43,7 @@ from ...utils import ( ...@@ -43,6 +43,7 @@ from ...utils import (
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
...@@ -89,28 +90,6 @@ RANGE_LIST = [ ...@@ -89,28 +90,6 @@ RANGE_LIST = [
] ]
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int): def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int):
assert num_frames > 0, "video_length should be greater than 0" assert num_frames > 0, "video_length should be greater than 0"
...@@ -218,7 +197,7 @@ class PIAPipeline( ...@@ -218,7 +197,7 @@ class PIAPipeline(
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt( def encode_prompt(
...@@ -621,7 +600,7 @@ class PIAPipeline( ...@@ -621,7 +600,7 @@ class PIAPipeline(
) )
_, _, _, scaled_height, scaled_width = shape _, _, _, scaled_height, scaled_width = shape
image = self.image_processor.preprocess(image) image = self.video_processor.preprocess(image)
image = image.to(device, dtype) image = image.to(device, dtype)
if isinstance(generator, list): if isinstance(generator, list):
...@@ -959,7 +938,7 @@ class PIAPipeline( ...@@ -959,7 +938,7 @@ class PIAPipeline(
video = latents video = latents
else: else:
video_tensor = self.decode_latents(latents) video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models # 10. Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
......
...@@ -21,11 +21,12 @@ import PIL.Image ...@@ -21,11 +21,12 @@ import PIL.Image
import torch import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging, replace_example_docstring from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import is_compiled_module, randn_tensor from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
...@@ -61,28 +62,6 @@ def _append_dims(x, target_dims): ...@@ -61,28 +62,6 @@ def _append_dims(x, target_dims):
return x[(...,) + (None,) * dims_to_append] return x[(...,) + (None,) * dims_to_append]
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps( def retrieve_timesteps(
scheduler, scheduler,
...@@ -199,7 +178,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): ...@@ -199,7 +178,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
def _encode_image( def _encode_image(
self, self,
...@@ -211,8 +190,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): ...@@ -211,8 +190,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
dtype = next(self.image_encoder.parameters()).dtype dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
image = self.image_processor.pil_to_numpy(image) image = self.video_processor.pil_to_numpy(image)
image = self.image_processor.numpy_to_pt(image) image = self.video_processor.numpy_to_pt(image)
# We normalize the image before resizing to match with the original implementation. # We normalize the image before resizing to match with the original implementation.
# Then we unnormalize it after resizing. # Then we unnormalize it after resizing.
...@@ -520,7 +499,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): ...@@ -520,7 +499,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
fps = fps - 1 fps = fps - 1
# 4. Encode input image using VAE # 4. Encode input image using VAE
image = self.image_processor.preprocess(image, height=height, width=width).to(device) image = self.video_processor.preprocess(image, height=height, width=width).to(device)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
image = image + noise_aug_strength * noise image = image + noise_aug_strength * noise
...@@ -626,7 +605,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline): ...@@ -626,7 +605,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
if needs_upcasting: if needs_upcasting:
self.vae.to(dtype=torch.float16) self.vae.to(dtype=torch.float16)
frames = self.decode_latents(latents, num_frames, decode_chunk_size) frames = self.decode_latents(latents, num_frames, decode_chunk_size)
frames = tensor2vid(frames, self.image_processor, output_type=output_type) frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
else: else:
frames = latents frames = latents
......
...@@ -15,11 +15,9 @@ ...@@ -15,11 +15,9 @@
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch import torch
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
...@@ -33,6 +31,7 @@ from ...utils import ( ...@@ -33,6 +31,7 @@ from ...utils import (
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput from . import TextToVideoSDPipelineOutput
...@@ -59,28 +58,6 @@ EXAMPLE_DOC_STRING = """ ...@@ -59,28 +58,6 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin): class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
r""" r"""
Pipeline for text-to-video generation. Pipeline for text-to-video generation.
...@@ -127,7 +104,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve ...@@ -127,7 +104,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
scheduler=scheduler, scheduler=scheduler,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt( def _encode_prompt(
...@@ -652,7 +629,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve ...@@ -652,7 +629,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
video = latents video = latents
else: else:
video_tensor = self.decode_latents(latents) video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 9. Offload all models # 9. Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
......
...@@ -16,11 +16,9 @@ import inspect ...@@ -16,11 +16,9 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL.Image
import torch import torch
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder from ...models.lora import adjust_lora_scale_text_encoder
...@@ -34,6 +32,7 @@ from ...utils import ( ...@@ -34,6 +32,7 @@ from ...utils import (
unscale_lora_layers, unscale_lora_layers,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput from . import TextToVideoSDPipelineOutput
...@@ -94,69 +93,6 @@ def retrieve_latents( ...@@ -94,69 +93,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output") raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
def preprocess_video(video):
supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image)
if isinstance(video, supported_formats):
video = [video]
elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}"
)
if isinstance(video[0], PIL.Image.Image):
video = [np.array(frame) for frame in video]
if isinstance(video[0], np.ndarray):
video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)
if video.dtype == np.uint8:
video = np.array(video).astype(np.float32) / 255.0
if video.ndim == 4:
video = video[None, ...]
video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3))
elif isinstance(video[0], torch.Tensor):
video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0)
# don't need any preprocess if the video is latents
channel = video.shape[1]
if channel == 4:
return video
# move channels before num_frames
video = video.permute(0, 2, 1, 3, 4)
# normalize video
video = 2.0 * video - 1.0
return video
class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin): class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
r""" r"""
Pipeline for text-guided video-to-video generation. Pipeline for text-guided video-to-video generation.
...@@ -203,7 +139,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv ...@@ -203,7 +139,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
scheduler=scheduler, scheduler=scheduler,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt( def _encode_prompt(
...@@ -687,7 +623,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv ...@@ -687,7 +623,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Preprocess video # 4. Preprocess video
video = preprocess_video(video) video = self.video_processor.preprocess_video(video)
# 5. Prepare timesteps # 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
...@@ -749,7 +685,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv ...@@ -749,7 +685,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
video = latents video = latents
else: else:
video_tensor = self.decode_latents(latents) video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models # 10. Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
class VideoProcessor(VaeImageProcessor):
r"""Simple video processor."""
def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
r"""
Preprocesses input video(s).
Args:
video: The input video. It can be one of the following:
* A list of PIL images.
* A list of lists of PIL images.
* A 4D Torch tensor (expected shape: (num_frames, num_channels, height, width)).
* A 4D NumPy array (expected shape: (num_frames, height, width, num_channels)).
* A list of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
* A list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
* A 5D NumPy array (expected shape: (batch_size, num_frames, height, width, num_channels)).
* A 5D Torch tensor (expected shape: (batch_size, num_frames, num_channels, height, width)).
height (`int`, *optional*, defaults to `None`):
The height of the preprocessed video frames. If `None`, `get_default_height_width()` is used to get the default height.
width (`int`, *optional*, defaults to `None`):
The width of the preprocessed video frames. If `None`, `get_default_height_width()` is used to get the default width.
"""
if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
warnings.warn(
"Passing `video` as a list of 5d np.ndarray is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
FutureWarning,
)
video = np.concatenate(video, axis=0)
if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
warnings.warn(
"Passing `video` as a list of 5d torch.Tensor is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
FutureWarning,
)
video = torch.cat(video, axis=0)
# ensure the input is a list of videos:
# - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
# - if it is a single video, it is converted to a list of one video.
if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
video = list(video)
elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
video = [video]
elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
video = video
else:
raise ValueError(
"Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
)
video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)
# move the number of channels before the number of frames.
video = video.permute(0, 2, 1, 3, 4)
return video
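As an illustration of the input formats listed in the `preprocess_video` docstring, a small sketch with dummy data (shapes chosen to match the tests added below; not part of the commit):

import numpy as np
import PIL.Image
import torch
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, do_normalize=True)

frames_pil = [PIL.Image.new("RGB", (8, 8)) for _ in range(5)]  # a list of PIL images -> one video
video_np = np.random.rand(1, 5, 8, 8, 3)                       # a 5-d array: (batch, frames, height, width, channels)
video_pt = torch.rand(1, 5, 3, 8, 8)                           # a 5-d tensor: (batch, frames, channels, height, width)

for inp in (frames_pil, video_np, video_pt):
    out = video_processor.preprocess_video(inp)
    print(out.shape)  # torch.Size([1, 3, 5, 8, 8]): (batch, channels, frames, height, width)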
def postprocess_video(
self, video: torch.Tensor, output_type: str = "np"
) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
r"""
Converts a video tensor to a list of frames for export.
Args:
video (`torch.Tensor`): The video as a tensor.
output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
"""
batch_size = video.shape[0]
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = self.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
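And a brief sketch of the three `output_type` options for `postprocess_video` (shapes refer to the dummy clip used above; illustrative only, not part of the commit):

import torch
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, do_normalize=True)
video_tensor = video_processor.preprocess_video(torch.rand(1, 5, 3, 8, 8))   # (1, 3, 5, 8, 8)

as_np = video_processor.postprocess_video(video_tensor, output_type="np")    # np.ndarray, shape (1, 5, 8, 8, 3)
as_pt = video_processor.postprocess_video(video_tensor, output_type="pt")    # torch.Tensor, shape (1, 5, 3, 8, 8)
as_pil = video_processor.postprocess_video(video_tensor, output_type="pil")  # list of lists of PIL.Image.Image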
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import PIL.Image
import torch
from parameterized import parameterized
from diffusers.video_processor import VideoProcessor
np.random.seed(0)
torch.manual_seed(0)
class VideoProcessorTest(unittest.TestCase):
def get_dummy_sample(self, input_type):
batch_size = 1
num_frames = 5
num_channels = 3
height = 8
width = 8
def generate_image():
return PIL.Image.fromarray(np.random.randint(0, 256, size=(height, width, num_channels)).astype("uint8"))
def generate_4d_array():
return np.random.rand(num_frames, height, width, num_channels)
def generate_5d_array():
return np.random.rand(batch_size, num_frames, height, width, num_channels)
def generate_4d_tensor():
return torch.rand(num_frames, num_channels, height, width)
def generate_5d_tensor():
return torch.rand(batch_size, num_frames, num_channels, height, width)
if input_type == "list_images":
sample = [generate_image() for _ in range(num_frames)]
elif input_type == "list_list_images":
sample = [[generate_image() for _ in range(num_frames)] for _ in range(num_frames)]
elif input_type == "list_4d_np":
sample = [generate_4d_array() for _ in range(num_frames)]
elif input_type == "list_list_4d_np":
sample = [[generate_4d_array() for _ in range(num_frames)] for _ in range(num_frames)]
elif input_type == "list_5d_np":
sample = [generate_5d_array() for _ in range(num_frames)]
elif input_type == "5d_np":
sample = generate_5d_array()
elif input_type == "list_4d_pt":
sample = [generate_4d_tensor() for _ in range(num_frames)]
elif input_type == "list_list_4d_pt":
sample = [[generate_4d_tensor() for _ in range(num_frames)] for _ in range(num_frames)]
elif input_type == "list_5d_pt":
sample = [generate_5d_tensor() for _ in range(num_frames)]
elif input_type == "5d_pt":
sample = generate_5d_tensor()
return sample
def to_np(self, video):
# List of images.
if isinstance(video[0], PIL.Image.Image):
video = np.stack([np.array(i) for i in video], axis=0)
# List of list of images.
elif isinstance(video, list) and isinstance(video[0][0], PIL.Image.Image):
frames = []
for vid in video:
all_current_frames = np.stack([np.array(i) for i in vid], axis=0)
frames.append(all_current_frames)
video = np.stack([np.array(frame) for frame in frames], axis=0)
# List of 4d/5d {ndarrays, torch tensors}.
elif isinstance(video, list) and isinstance(video[0], (torch.Tensor, np.ndarray)):
if isinstance(video[0], np.ndarray):
video = np.stack(video, axis=0) if video[0].ndim == 4 else np.concatenate(video, axis=0)
else:
if video[0].ndim == 4:
video = np.stack([i.cpu().numpy().transpose(0, 2, 3, 1) for i in video], axis=0)
elif video[0].ndim == 5:
video = np.concatenate([i.cpu().numpy().transpose(0, 1, 3, 4, 2) for i in video], axis=0)
# List of list of 4d/5d {ndarrays, torch tensors}.
elif (
isinstance(video, list)
and isinstance(video[0], list)
and isinstance(video[0][0], (torch.Tensor, np.ndarray))
):
all_frames = []
for list_of_videos in video:
temp_frames = []
for vid in list_of_videos:
if vid.ndim == 4:
current_vid_frames = np.stack(
[i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(1, 2, 0) for i in vid],
axis=0,
)
elif vid.ndim == 5:
current_vid_frames = np.concatenate(
[i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(0, 2, 3, 1) for i in vid],
axis=0,
)
temp_frames.append(current_vid_frames)
temp_frames = np.stack(temp_frames, axis=0)
all_frames.append(temp_frames)
video = np.concatenate(all_frames, axis=0)
# Just 5d {ndarrays, torch tensors}.
elif isinstance(video, (torch.Tensor, np.ndarray)) and video.ndim == 5:
video = video if isinstance(video, np.ndarray) else video.cpu().numpy().transpose(0, 1, 3, 4, 2)
return video
@parameterized.expand(["list_images", "list_list_images"])
def test_video_processor_pil(self, input_type):
video_processor = VideoProcessor(do_resize=False, do_normalize=True)
input = self.get_dummy_sample(input_type=input_type)
for output_type in ["pt", "np", "pil"]:
out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type)
out_np = self.to_np(out)
input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input)
assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"
@parameterized.expand(["list_4d_np", "list_5d_np", "5d_np"])
def test_video_processor_np(self, input_type):
video_processor = VideoProcessor(do_resize=False, do_normalize=True)
input = self.get_dummy_sample(input_type=input_type)
for output_type in ["pt", "np", "pil"]:
out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type)
out_np = self.to_np(out)
input_np = (
(self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input)
)
assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"
@parameterized.expand(["list_4d_pt", "list_5d_pt", "5d_pt"])
def test_video_processor_pt(self, input_type):
video_processor = VideoProcessor(do_resize=False, do_normalize=True)
input = self.get_dummy_sample(input_type=input_type)
for output_type in ["pt", "np", "pil"]:
out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type)
out_np = self.to_np(out)
input_np = (
(self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input)
)
assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"