Unverified Commit ba49272d authored by Andranik Movsisyan, committed by GitHub

[Pipeline] Add TextToVideoZeroPipeline (#2954)



* add TextToVideoZeroPipeline and CrossFrameAttnProcessor

* add docs for text-to-video zero

* add teaser image for text-to-video zero docs

* Fix review changes. Add Documentation. Add test

* clean up the codes in pipeline_text_to_video.py. Add descriptive comments and docstrings

* make style && make quality

* make fix-copies

* make requested changes to docs. use huggingface server links for resources, delete res folder

* make style && make quality && make fix-copies

* make style && make quality

* Apply suggestions from code review

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 074d281a
@@ -206,6 +206,8 @@
title: Stochastic Karras VE
- local: api/pipelines/text_to_video
title: Text-to-Video
- local: api/pipelines/text_to_video_zero
title: Text-to-Video Zero
- local: api/pipelines/unclip
title: UnCLIP
- local: api/pipelines/latent_diffusion_uncond
@@ -83,6 +83,7 @@ available a colab notebook to directly try them out.
| [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [vq_diffusion](./vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
| [text_to_video_zero](./text_to_video_zero) | [Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://arxiv.org/abs/2303.13439) | Text-to-Video Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Zero-Shot Text-to-Video Generation
## Overview
[Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://arxiv.org/abs/2303.13439) by
Levon Khachatryan,
Andranik Movsisyan,
Vahram Tadevosyan,
Roberto Henschel,
[Zhangyang Wang](https://www.ece.utexas.edu/people/faculty/atlas-wang), Shant Navasardyan, [Humphrey Shi](https://www.humphreyshi.com).
Our method Text2Video-Zero enables zero-shot video generation using either
1. A textual prompt, or
2. A prompt combined with guidance from poses or edges, or
3. Video Instruct-Pix2Pix, i.e., instruction-guided video editing.
Results are temporally consistent and closely follow the guidance and textual prompts.
![teaser-img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2v_zero_teaser.png)
The abstract of the paper is the following:
*Recent text-to-video generation approaches rely on computationally heavy training and require large-scale video datasets. In this paper, we introduce a new task of zero-shot text-to-video generation and propose a low-cost approach (without any training or optimization) by leveraging the power of existing text-to-image synthesis methods (e.g., Stable Diffusion), making them suitable for the video domain.
Our key modifications include (i) enriching the latent codes of the generated frames with motion dynamics to keep the global scene and the background time consistent; and (ii) reprogramming frame-level self-attention using a new cross-frame attention of each frame on the first frame, to preserve the context, appearance, and identity of the foreground object.
Experiments show that this leads to low overhead, yet high-quality and remarkably consistent video generation. Moreover, our approach is not limited to text-to-video synthesis but is also applicable to other tasks such as conditional and content-specialized video generation, and Video Instruct-Pix2Pix, i.e., instruction-guided video editing.
As experiments show, our method performs comparably or sometimes better than recent approaches, despite not being trained on additional video data.*
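To build intuition for modification (i), the toy sketch below shows how enriching frame latents with motion dynamics reduces to translating each frame's latent code with `grid_sample`, the same operator the pipeline uses internally. The function `translate_latents` and its arguments are hypothetical illustrations, not part of the pipeline API.
```python
import torch
import torch.nn.functional as F


def translate_latents(latents, dx, dy):
    """Toy version of motion-dynamics enrichment: shift frame k by k * (dx, dy) pixels."""
    f, c, h, w = latents.shape
    ys = torch.arange(h, dtype=latents.dtype).view(h, 1).expand(h, w)
    xs = torch.arange(w, dtype=latents.dtype).view(1, w).expand(h, w)
    grid = torch.stack((xs, ys), dim=-1).unsqueeze(0).repeat(f, 1, 1, 1)  # (f, h, w, 2) pixel coordinates
    shift = torch.stack((dx * torch.arange(f), dy * torch.arange(f)), dim=-1).to(latents.dtype)
    grid = grid + shift[:, None, None, :]  # frame k is sampled from a location shifted by k * (dx, dy)
    grid[..., 0] = grid[..., 0] / (w - 1) * 2 - 1  # normalize to [-1, 1] for grid_sample
    grid[..., 1] = grid[..., 1] / (h - 1) * 2 - 1
    return F.grid_sample(latents, grid, mode="nearest", padding_mode="reflection", align_corners=True)


latents = torch.randn(8, 4, 64, 64)  # 8 frames of Stable Diffusion-sized latents
warped = translate_latents(latents, dx=2.0, dy=2.0)
print(warped.shape)  # torch.Size([8, 4, 64, 64])
```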
Resources:
* [Project Page](https://text2video-zero.github.io/)
* [Paper](https://arxiv.org/abs/2303.13439)
* [Original Code](https://github.com/Picsart-AI-Research/Text2Video-Zero)
## Available Pipelines:
| Pipeline | Tasks | Demo |
|---|---|:---:|
| [TextToVideoZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py) | *Zero-shot Text-to-Video Generation* | [🤗 Space](https://huggingface.co/spaces/PAIR/Text2Video-Zero) |
## Usage example
### Text-To-Video
To generate a video from a prompt, run the following Python code:
```python
import torch
import imageio
from diffusers import TextToVideoZeroPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

prompt = "A panda is playing guitar on times square"
result = pipe(prompt=prompt).images
# the pipeline returns float frames in [0, 1]; convert them to uint8 before saving
result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("video.mp4", result, fps=4)
```
You can change these parameters in the pipeline call (see the example after this list):
* Motion field strength (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1):
    * `motion_field_strength_x` and `motion_field_strength_y`. Default: `motion_field_strength_x=12`, `motion_field_strength_y=12`
* `T` and `T'` (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1):
    * `t0` and `t1` in the range `{0, ..., num_inference_steps}`. Default: `t0=44`, `t1=47`
* Video length:
    * `video_length`, the number of frames to be generated. Default: `video_length=8`
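For example, continuing with the `pipe` object created above, a call that asks for more frames and stronger motion could look like the following (the prompt and parameter values are illustrative only):
```python
result = pipe(
    prompt="A panda is playing guitar on times square",
    video_length=16,  # generate 16 frames instead of the default 8
    motion_field_strength_x=20,  # stronger motion along the x-axis
    motion_field_strength_y=20,  # stronger motion along the y-axis
    t0=44,  # must satisfy 0 <= t0 < t1 <= num_inference_steps - 1 (see Sect. 3.3.1 of the paper)
    t1=47,
).images
result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("longer_video.mp4", result, fps=4)
```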
### Text-To-Video with Pose Control
To generate a video from a prompt with additional pose control:
1. Download a demo video
```python
from huggingface_hub import hf_hub_download
filename = "__assets__/poses_skeleton_gifs/dance1_corr.mp4"
repo_id = "PAIR/Text2Video-Zero"
video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
```
2. Read video containing extracted pose images
```python
import imageio
from PIL import Image

reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
pose_images = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```
To extract poses from an actual video, refer to the [ControlNet documentation](./stable_diffusion/controlnet); a minimal sketch using `controlnet_aux` is also shown after step 3.
3. Run `StableDiffusionControlNetPipeline` with our custom attention processor
```python
import torch
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
model_id = "runwayml/stable-diffusion-v1-5"
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
model_id, controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
# Set the attention processor
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
# fix latents for all frames
latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
prompt = "Darth Vader dancing in a desert"
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
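As noted in step 2, you can also start from a regular video and extract the poses yourself. A minimal sketch, assuming the `controlnet_aux` package is installed and `frames` is a hypothetical list of `PIL.Image` frames read from your own clip:
```python
from controlnet_aux import OpenposeDetector

# `frames` is assumed to be a list of PIL.Image video frames (hypothetical variable)
open_pose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
pose_images = [open_pose(frame) for frame in frames]
```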
### Text-To-Video with Edge Control
To generate a video from a prompt with additional Canny edge control,
follow the steps described above for pose-guided generation using the [Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny).
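A minimal sketch of the edge-guided setup is shown below. It mirrors the pose example above and assumes `canny_edges` is a hypothetical list of `PIL.Image` Canny edge maps, for example read from a pre-computed edge video in the same way as the pose frames:
```python
import torch
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor

# `canny_edges` is assumed to be a list of PIL.Image Canny edge maps (hypothetical variable)
model_id = "runwayml/stable-diffusion-v1-5"
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id, controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# the cross-frame attention processor keeps appearance consistent across frames
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))

# fix latents for all frames
latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)

prompt = "A robot is dancing in the desert"
result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```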
### Video Instruct-Pix2Pix
To perform text-guided video editing (with [InstructPix2Pix](./stable_diffusion/pix2pix)):
1. Download a demo video
```python
from huggingface_hub import hf_hub_download
filename = "__assets__/pix2pix video/camel.mp4"
repo_id = "PAIR/Text2Video-Zero"
video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
```
2. Read video from path
```python
import imageio
from PIL import Image

reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```
3. Run `StableDiffusionInstructPix2PixPipeline` with our custom attention processor
```python
import torch
import imageio
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=3))
prompt = "make it Van Gogh Starry Night style"
result = pipe(prompt=[prompt] * len(video), image=video).images
imageio.mimsave("edited_video.mp4", result, fps=4)
```
### Dreambooth specialization
Methods **Text-To-Video**, **Text-To-Video with Pose Control** and **Text-To-Video with Edge Control**
can run with custom [DreamBooth](../training/dreambooth) models, as shown below for
the [Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny) and
the [Avatar style DreamBooth](https://huggingface.co/PAIR/text2video-zero-controlnet-canny-avatar) model.
1. Download a demo video from the Hugging Face Hub
```python
from huggingface_hub import hf_hub_download
filename = "__assets__/canny_videos_mp4/girl_turning.mp4"
repo_id = "PAIR/Text2Video-Zero"
video_path = hf_hub_download(repo_type="space", repo_id=repo_id, filename=filename)
```
2. Read video from path
```python
import imageio
from PIL import Image

reader = imageio.get_reader(video_path, "ffmpeg")
frame_count = 8
video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]
```
3. Run `StableDiffusionControlNetPipeline` with a custom-trained DreamBooth model
```python
import torch
import imageio
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor

# set model id to the custom DreamBooth model
model_id = "PAIR/text2video-zero-controlnet-canny-avatar"
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id, controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# Set the attention processor
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))

# fix latents for all frames
latents = torch.randn((1, 4, 64, 64), device="cuda", dtype=torch.float16).repeat(len(video), 1, 1, 1)

prompt = "oil painting of a beautiful girl avatar style"
result = pipe(prompt=[prompt] * len(video), image=video, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
You can find available DreamBooth-trained models with [this filter link](https://huggingface.co/models?search=dreambooth).
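DreamBooth checkpoints can also be used for plain **Text-To-Video** generation without ControlNet by loading them directly into `TextToVideoZeroPipeline`. A minimal sketch, where `"path/to/dreambooth-model"` is a placeholder for your own DreamBooth-fine-tuned Stable Diffusion checkpoint (a local folder or a Hub repository id):
```python
import torch
import imageio
from diffusers import TextToVideoZeroPipeline

# "path/to/dreambooth-model" is a placeholder for a DreamBooth-fine-tuned Stable Diffusion checkpoint
pipe = TextToVideoZeroPipeline.from_pretrained("path/to/dreambooth-model", torch_dtype=torch.float16).to("cuda")

prompt = "a photo of sks dog surfing a wave"  # use the identifier token your DreamBooth model was trained with
result = pipe(prompt=prompt).images
result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("video.mp4", result, fps=4)
```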
## TextToVideoZeroPipeline
[[autodoc]] TextToVideoZeroPipeline
- all
- __call__
\ No newline at end of file
@@ -137,6 +137,7 @@ else:
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
VersatileDiffusionDualGuidedPipeline,
@@ -68,7 +68,7 @@ else:
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .text_to_video_synthesis import TextToVideoSDPipeline
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
@@ -29,3 +29,4 @@ except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401
from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from torch.nn.functional import grid_sample
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput
# reshape a tensor from (B*F, C, H, W) to (B, C, F, H, W)
def rearrange_0(tensor, f):
F, C, H, W = tensor.size()
tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4))
return tensor
# reshape a tensor from (B, C, F, H, W) back to (B*F, C, H, W)
def rearrange_1(tensor):
B, C, F, H, W = tensor.size()
return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W))
# reshape a tensor from (B*F, D, C) to (B, F, D, C)
def rearrange_3(tensor, f):
F, D, C = tensor.size()
return torch.reshape(tensor, (F // f, f, D, C))
# reshape a tensor from (B, F, D, C) back to (B*F, D, C)
def rearrange_4(tensor):
B, F, D, C = tensor.size()
return torch.reshape(tensor, (B * F, D, C))
class CrossFrameAttnProcessor:
"""
Cross frame attention processor. For each frame the self-attention is replaced with attention with first frame
Args:
batch_size: The number that represents actual batch size, other than the frames.
For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
equal to 2, due to classifier-free guidance.
"""
def __init__(self, batch_size=2):
self.batch_size = batch_size
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
is_cross_attention = encoder_hidden_states is not None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# Sparse Attention
if not is_cross_attention:
video_length = key.size()[0] // self.batch_size
first_frame_index = [0] * video_length
# rearrange keys to have batch and frames in the 1st and 2nd dims respectively
key = rearrange_3(key, video_length)
key = key[:, first_frame_index]
# rearrange values to have batch and frames in the 1st and 2nd dims respectively
value = rearrange_3(value, video_length)
value = value[:, first_frame_index]
# rearrange back to original shape
key = rearrange_4(key)
value = rearrange_4(value)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
@dataclass
class TextToVideoPipelineOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
def coords_grid(batch, ht, wd, device):
# Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def warp_single_latent(latent, reference_flow):
"""
Warp latent of a single frame with given flow
Args:
latent: latent code of a single frame
reference_flow: flow which to warp the latent with
Returns:
warped: warped latent
"""
_, _, H, W = reference_flow.size()
_, _, h, w = latent.size()
coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype)
coords_t0 = coords0 + reference_flow
coords_t0[:, 0] /= W
coords_t0[:, 1] /= H
coords_t0 = coords_t0 * 2.0 - 1.0
coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear")
coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1))
warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection")
return warped
def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype):
"""
Create translation motion field
Args:
motion_field_strength_x: motion strength along x-axis
motion_field_strength_y: motion strength along y-axis
frame_ids: indexes of the frames the latents of which are being processed.
This is needed when we perform chunk-by-chunk inference
device: device
dtype: dtype
Returns:
    reference_flow: translation motion field of shape (len(frame_ids), 2, 512, 512)
"""
seq_length = len(frame_ids)
reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype)
for fr_idx in range(seq_length):
reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (frame_ids[fr_idx])
reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (frame_ids[fr_idx])
return reference_flow
def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents):
"""
Creates translation motion and warps the latents accordingly
Args:
motion_field_strength_x: motion strength along x-axis
motion_field_strength_y: motion strength along y-axis
frame_ids: indexes of the frames the latents of which are being processed.
This is needed when we perform chunk-by-chunk inference
latents: latent codes of frames
Returns:
warped_latents: warped latents
"""
motion_field = create_motion_field(
motion_field_strength_x=motion_field_strength_x,
motion_field_strength_y=motion_field_strength_y,
frame_ids=frame_ids,
device=latents.device,
dtype=latents.dtype,
)
warped_latents = latents.clone().detach()
for i in range(len(warped_latents)):
warped_latents[i] = warp_single_latent(latents[i][None], motion_field[i][None])
return warped_latents
class TextToVideoZeroPipeline(StableDiffusionPipeline):
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods
the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
def forward_loop(self, x_t0, t0, t1, generator):
"""
Perform DDPM forward process from time t0 to t1. This is the same as adding noise with the corresponding variance.
Args:
x_t0: latent code at time t0
t0: t0
t1: t1
generator: torch.Generator object
Returns:
x_t1: forward process applied to x_t0 from time t0 to t1.
"""
eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
alpha_vec = torch.prod(self.scheduler.alphas[t0:t1])
x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps
return x_t1
def backward_loop(
self,
latents,
timesteps,
prompt_embeds,
guidance_scale,
callback,
callback_steps,
num_warmup_steps,
extra_step_kwargs,
cross_attention_kwargs=None,
):
"""
Perform backward process given a list of time steps
Args:
latents: Latents at time timesteps[0].
timesteps: time steps, along which to perform backward process.
prompt_embeds: Pre-generated text embeddings
guidance_scale:
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
extra_step_kwargs: extra_step_kwargs.
cross_attention_kwargs: cross_attention_kwargs.
num_warmup_steps: number of warmup steps.
Returns:
latents: latents of backward process output at time timesteps[-1]
"""
do_classifier_free_guidance = guidance_scale > 1.0
with self.progress_bar(total=len(timesteps)) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
return latents.clone().detach()
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
video_length: Optional[int] = 8,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
motion_field_strength_x: float = 12,
motion_field_strength_y: float = 12,
output_type: Optional[str] = "tensor",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
t0: int = 44,
t1: int = 47,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
video_length (`int`, *optional*, defaults to 8): The number of generated video frames
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"tensor"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
motion_field_strength_x (`float`, *optional*, defaults to 12):
Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
Sect. 3.3.1.
motion_field_strength_y (`float`, *optional*, defaults to 12):
Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
Sect. 3.3.1.
t0 (`int`, *optional*, defaults to 44):
Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
[paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
t1 (`int`, *optional*, defaults to 47):
Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
[paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
Returns:
[`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`]:
The output contains a ndarray of the generated video, when output_type != 'latent', otherwise the latent
codes of the generated video, and a list of `bool`s denoting whether the corresponding generated frame
likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
assert video_length > 0
frame_ids = list(range(video_length))
assert num_videos_per_prompt == 1
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
# Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# Encode input prompt
prompt_embeds = self._encode_prompt(
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
)
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# Perform the first backward process up to time T_1
x_1_t1 = self.backward_loop(
timesteps=timesteps[: -t1 - 1],
prompt_embeds=prompt_embeds,
latents=latents,
guidance_scale=guidance_scale,
callback=callback,
callback_steps=callback_steps,
extra_step_kwargs=extra_step_kwargs,
num_warmup_steps=num_warmup_steps,
)
# Perform the second backward process up to time T_0
x_1_t0 = self.backward_loop(
timesteps=timesteps[-t1 - 1 : -t0 - 1],
prompt_embeds=prompt_embeds,
latents=x_1_t1,
guidance_scale=guidance_scale,
callback=callback,
callback_steps=callback_steps,
extra_step_kwargs=extra_step_kwargs,
num_warmup_steps=num_warmup_steps,
)
# Propagate first frame latents at time T_0 to remaining frames
x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1)
# Add motion in latents at time T_0
x_2k_t0 = create_motion_field_and_warp_latents(
motion_field_strength_x=motion_field_strength_x,
motion_field_strength_y=motion_field_strength_y,
latents=x_2k_t0,
frame_ids=frame_ids[1:],
)
# Perform forward process up to time T_1
x_2k_t1 = self.forward_loop(
x_t0=x_2k_t0,
t0=timesteps[-t0 - 1].item(),
t1=timesteps[-t1 - 1].item(),
generator=generator,
)
# Perform backward process from time T_1 to 0
x_1k_t1 = torch.cat([x_1_t1, x_2k_t1])
b, l, d = prompt_embeds.size()
prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)
self.scheduler.set_timesteps(num_inference_steps, device=device)
x_1k_0 = self.backward_loop(
timesteps=timesteps[-t1 - 1 :],
prompt_embeds=prompt_embeds,
latents=x_1k_t1,
guidance_scale=guidance_scale,
callback=callback,
callback_steps=callback_steps,
extra_step_kwargs=extra_step_kwargs,
num_warmup_steps=num_warmup_steps,
)
latents = x_1k_0
# if sequential model offloading is enabled, offload the unet manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
torch.cuda.empty_cache()
if output_type == "latent":
image = latents
has_nsfw_concept = None
else:
image = self.decode_latents(latents)
# Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -407,6 +407,21 @@ class TextToVideoSDPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class TextToVideoZeroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class UnCLIPImageVariationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from huggingface_hub import hf_hub_download
from diffusers import DDIMScheduler, TextToVideoZeroPipeline
from diffusers.utils import require_torch_gpu, slow
from ...test_pipelines_common import assert_mean_pixel_difference
@slow
@require_torch_gpu
class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
def test_full_model(self):
model_id = "runwayml/stable-diffusion-v1-5"
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(0)
prompt = "A bear is playing a guitar on Times Square"
result = pipe(prompt=prompt, generator=generator).images
# download the reference result from the Hub, then load it locally
filename = "text-to-video/A bear is playing a guitar on Times Square.pt"
expected_path = hf_hub_download(repo_type="dataset", repo_id="hf-internal-testing/diffusers-images", filename=filename)
expected_result = torch.load(expected_path)
assert_mean_pixel_difference(result, expected_result)