Commit 77605806 authored by chenpangpang's avatar chenpangpang
Browse files

feat: 初始提交

parent 2f260963
Pipeline #1520 failed with stages
.idea
chenyh
.vscode
\ No newline at end of file
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [XIN MA] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
\ No newline at end of file
## Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models<br><sub>Official PyTorch Implementation</sub>
[![Arxiv](https://img.shields.io/badge/Arxiv-b31b1b.svg)](https://arxiv.org/abs/2407.15642)
[![Project Page](https://img.shields.io/badge/Project-Website-blue)](https://maxin-cn.github.io/cinemo_project/)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/maxin-cn/Cinemo)
> [**Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models**](https://maxin-cn.github.io/cinemo_project/)<br>
> [Xin Ma](https://maxin-cn.github.io/), [Yaohui Wang*†](https://wyhsirius.github.io/), [Gengyun Jia](https://scholar.google.com/citations?user=_04pkGgAAAAJ&hl=zh-CN), [Xinyuan Chen](https://scholar.google.com/citations?user=3fWSC8YAAAAJ), [Yuan-Fang Li](https://users.monash.edu/~yli/), [Cunjian Chen*](https://cunjian.github.io/), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ&hl=zh-CN) <br>
> (*Corresponding authors, †Project Lead)
This repo contains pre-trained weights, and sampling code of Cinemo. Please visit our [project page](https://maxin-cn.github.io/cinemo_project/) for more results.
<!--
In this project, we propose a novel method called Cinemo, which can perform motion-controllable image animation with strong consistency and smoothness. To improve motion smoothness, Cinemo learns the distribution of motion residuals, rather than directly generating subsequent frames. Additionally, a structural similarity index-based method is proposed to control the motion intensity. Furthermore, we propose a noise refinement technique based on discrete cosine transformation to ensure temporal consistency. These three methods help Cinemo generate highly consistent, smooth, and motion-controlled image animation results. Compared to previous methods, Cinemo offers simpler and more precise user control and better generative performance.
-->
<div align="center">
<img src="visuals/pipeline.svg">
</div>
## News
- (🔥 New) Jul. 29, 2024. 💥 [HuggingFace space](https://huggingface.co/spaces/maxin-cn/Cinemo) is added, you can also launch [gradio interface ](#gradio-interface) locally.
- (🔥 New) Jul. 23, 2024. 💥 Our paper is released on [arxiv](https://arxiv.org/abs/2407.15642).
- (🔥 New) Jun. 2, 2024. 💥 The inference code is released. The checkpoint can be found [here](https://huggingface.co/maxin-cn/Cinemo/tree/main).
## Setup
Download and set up the repo:
```bash
git clone https://github.com/maxin-cn/Cinemo
cd Cinemo
conda env create -f environment.yml
conda activate cinemo
```
<!--
We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
```bash
conda env create -f environment.yml
conda activate cinemo
```
-->
## Animation
You can sample from our **pre-trained Cinemo models** with [`animation.py`](pipelines/animation.py). Weights for our pre-trained Cinemo model can be found [here](https://huggingface.co/maxin-cn/Cinemo/tree/main). The script has various arguments for adjusting sampling steps, changing the classifier-free guidance scale, etc:
```bash
bash pipelines/animation.sh
```
Related model weights will be downloaded automatically and following results can be obtained,
<table style="width:100%; text-align:center;">
<tr>
<td align="center">Input image</td>
<td align="center">Output video</td>
<td align="center">Input image</td>
<td align="center">Output video</td>
</tr>
<tr>
<td align="center"><img src="visuals/animations/people_walking/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/animations/people_walking/people_walking.gif" width="100%"></td>
<td align="center"><img src="visuals/animations/sea_swell/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/animations/sea_swell/sea_swell.gif" width="100%"></td>
</tr>
<tr>
<td align="center" colspan="2">"People Walking"</td>
<td align="center" colspan="2">"Sea Swell"</td>
</tr>
<tr>
<td align="center"><img src="visuals/animations/girl_dancing_under_the_stars/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/animations/girl_dancing_under_the_stars/girl_dancing_under_the_stars.gif" width="100%"></td>
<td align="center"><img src="visuals/animations/dragon_glowing_eyes/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/animations/dragon_glowing_eyes/dragon_glowing_eyes.gif" width="100%"></td>
</tr>
<tr>
<td align="center" colspan="2">"Girl Dancing under the Stars"</td>
<td align="center" colspan="2">"Dragon Glowing Eyes"</td>
</tr>
<tr>
<td align="center"><img src="visuals/animations/bubbles__floating_upwards/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/animations/bubbles__floating_upwards/bubbles__floating_upwards.gif" width="100%"></td>
<td align="center"><img src="visuals/animations/snowman_waving_his_hand/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/animations/snowman_waving_his_hand/snowman_waving_his_hand.gif" width="100%"></td>
</tr>
<tr>
<td align="center" colspan="2">"Bubbles Floating upwards"</td>
<td align="center" colspan="2">"Snowman Waving his Hand"</td>
</tr>
</table>
## Gradio interface
We also provide a local gradio interface, just run:
```bash
python app.py
```
You can specify the `--share` and `--server_name` arguments to meet your requirement!
## Other Applications
You can also utilize Cinemo for other applications, such as motion transfer and video editing:
```bash
bash pipelines/video_editing.sh
```
Related checkpoints will be downloaded automatically and following results will be obtained,
<table style="width:100%; text-align:center;">
<tr>
<td align="center">Input video</td>
<td align="center">First frame</td>
<td align="center">Edited first frame</td>
<td align="center">Output video</td>
</tr>
<tr>
<td align="center"><img src="visuals/video_editing/origin/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif" width="100%"></td>
<td align="center"><img src="visuals/video_editing/origin/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/video_editing/edit/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/video_editing/edit/editing_a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.gif" width="100%"></td>
</tr>
</table>
or motion transfer,
<table style="width:100%; text-align:center;">
<tr>
<td align="center">Input video</td>
<td align="center">First frame</td>
<td align="center">Edited first frame</td>
<td align="center">Output video</td>
</tr>
<tr>
<td align="center"><img src="visuals/motion_transfer/origin/a_man_walking_on_the_beach.gif" width="100%"></td>
<td align="center"><img src="visuals/motion_transfer/origin/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/motion_transfer/edit/0.jpg" width="100%"></td>
<td align="center"><img src="visuals/motion_transfer/edit/a_man_walking_in_the_park.gif" width="100%"></td>
</tr>
</table>
## Contact Us
Xin Ma: xin.ma1@monash.edu,
Yaohui Wang: wangyaohui@pjlab.org.cn
## Citation
If you find this work useful for your research, please consider citing it.
```bibtex
@article{ma2024cinemo,
title={Cinemo: Latent Diffusion Transformer for Video Generation},
author={Ma, Xin and Wang, Yaohui and Jia, Gengyun and Chen, Xinyuan and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
journal={arXiv preprint arXiv:2407.15642},
year={2024}
}
```
## Acknowledgments
Cinemo has been greatly inspired by the following amazing works and teams: [LaVie](https://github.com/Vchitect/LaVie) and [SEINE](https://github.com/Vchitect/SEINE), we thank all the contributors for open-sourcing.
## License
The code and model weights are licensed under [LICENSE](LICENSE).
import gradio as gr
import torch
import argparse
from pipelines.pipeline_videogen import VideoGenPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf
import os, sys
sys.path.append(os.path.split(sys.path[0])[0])
from models import get_models
import imageio
from PIL import Image
import numpy as np
from datasets import video_transforms
from torchvision import transforms
from einops import rearrange, repeat
from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
from copy import deepcopy
import requests
from datetime import datetime
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/sample.yaml")
args = parser.parse_args()
args = OmegaConf.load(args.config)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 # torch.float16
unet = get_models(args).to(device, dtype=dtype)
if args.enable_vae_temporal_decoder:
if args.use_dct:
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
else:
vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
else:
vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae",).to(device, dtype=torch.float64)
vae = deepcopy(vae_for_base_content).to(dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device) # huge
# set eval mode
unet.eval()
vae.eval()
text_encoder.eval()
basedir = os.getcwd()
savedir = os.path.join(basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
savedir_sample = os.path.join(savedir, "sample")
os.makedirs(savedir, exist_ok=True)
def update_and_resize_image(input_image_path, height_slider, width_slider):
if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
else:
pil_image = Image.open(input_image_path).convert('RGB')
original_width, original_height = pil_image.size
if original_height == height_slider and original_width == width_slider:
return gr.Image(value=np.array(pil_image))
ratio1 = height_slider / original_height
ratio2 = width_slider / original_width
if ratio1 > ratio2:
new_width = int(original_width * ratio1)
new_height = int(original_height * ratio1)
else:
new_width = int(original_width * ratio2)
new_height = int(original_height * ratio2)
pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
left = (new_width - width_slider) / 2
top = (new_height - height_slider) / 2
right = left + width_slider
bottom = top + height_slider
pil_image = pil_image.crop((left, top, right, bottom))
return gr.Image(value=np.array(pil_image))
def update_textbox_and_save_image(input_image, height_slider, width_slider):
pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
original_width, original_height = pil_image.size
ratio1 = height_slider / original_height
ratio2 = width_slider / original_width
if ratio1 > ratio2:
new_width = int(original_width * ratio1)
new_height = int(original_height * ratio1)
else:
new_width = int(original_width * ratio2)
new_height = int(original_height * ratio2)
pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
left = (new_width - width_slider) / 2
top = (new_height - height_slider) / 2
right = left + width_slider
bottom = top + height_slider
pil_image = pil_image.crop((left, top, right, bottom))
img_path = os.path.join(savedir, "input_image.png")
pil_image.save(img_path)
return gr.Textbox(value=img_path), gr.Image(value=np.array(pil_image))
def prepare_image(image, vae, transform_video, device, dtype=torch.float16):
image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
image = transform_video(image)
image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
image = image.unsqueeze(2)
return image
def gen_video(input_image, prompt, negative_prompt, diffusion_step, height, width, scfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed):
torch.manual_seed(seed)
scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
videogen_pipeline = VideoGenPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
unet=unet).to(device)
# videogen_pipeline.enable_xformers_memory_efficient_attention()
transform_video = transforms.Compose([
video_transforms.ToTensorVideo(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
if args.use_dct:
base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
else:
base_content = prepare_image(input_image, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
if use_dctinit:
# filter params
print("Using DCT!")
base_content_repeat = repeat(base_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
# define filter
freq_filter = dct_low_pass_filter(dct_coefficients=base_content, percentage=dct_coefficients)
noise = torch.randn(1, 4, 15, 40, 64).to(device)
# add noise to base_content
diffuse_timesteps = torch.full((1,),int(noise_level))
diffuse_timesteps = diffuse_timesteps.long()
# 3d content
base_content_noise = scheduler.add_noise(
original_samples=base_content_repeat.to(device),
noise=noise,
timesteps=diffuse_timesteps.to(device))
# 3d content
latents = exchanged_mixed_dct_freq(noise=noise,
base_content=base_content_noise,
LPF_3d=freq_filter).to(dtype=torch.float16)
base_content = base_content.to(dtype=torch.float16)
videos = videogen_pipeline(prompt,
negative_prompt=negative_prompt,
latents=latents if use_dctinit else None,
base_content=base_content,
video_length=15,
height=height,
width=width,
num_inference_steps=diffusion_step,
guidance_scale=scfg_scale,
motion_bucket_id=100-motion_bucket_id,
enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
save_path = args.save_img_path + 'temp' + '.mp4'
# torchvision.io.write_video(save_path, videos[0], fps=8, video_codec='h264', options={'crf': '10'})
imageio.mimwrite(save_path, videos[0], fps=8, quality=7)
return save_path
if not os.path.exists(args.save_img_path):
os.makedirs(args.save_img_path)
with gr.Blocks() as demo:
gr.Markdown("<font color=red size=6.5><center>Cinemo: Consistent and Controllable Image Animation with Motion Diffusion Models</center></font>")
gr.Markdown(
"""<div style="display: flex;align-items: center;justify-content: center">
[<a href="https://arxiv.org/abs/2407.15642">Arxiv Report</a>] | [<a href="https://https://maxin-cn.github.io/cinemo_project/">Project Page</a>] | [<a href="https://github.com/maxin-cn/Cinemo">Github</a>]</div>
"""
)
with gr.Column(variant="panel"):
with gr.Row():
prompt_textbox = gr.Textbox(label="Prompt", lines=1)
negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
with gr.Row(equal_height=False):
with gr.Column():
with gr.Row():
input_image = gr.Image(label="Input Image", interactive=True)
result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
generate_button = gr.Button(value="Generate", variant='primary')
with gr.Accordion("Advanced options", open=False):
gr.Markdown(
"""
- Input image can be specified using the "Input Image URL" text box or uploaded by clicking or dragging the image to the "Input Image" box.
- Input image will be resized and/or center cropped to a given resolution (320 x 512) automatically.
- After setting the input image path, press the "Preview" button to visualize the resized input image.
"""
)
with gr.Column():
with gr.Row():
input_image_path = gr.Textbox(label="Input Image URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
preview_button = gr.Button(value="Preview")
with gr.Row():
sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)
with gr.Row():
seed_textbox = gr.Slider(label="Seed", value=100, minimum=1, maximum=int(1e8), step=1, interactive=True)
# seed_textbox = gr.Textbox(label="Seed", value=100)
# seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
# seed_button.click(fn=lambda: gr.Textbox(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
with gr.Row():
height = gr.Slider(label="Height", value=320, minimum=0, maximum=512, step=16, interactive=False)
width = gr.Slider(label="Width", value=512, minimum=0, maximum=512, step=16, interactive=False)
with gr.Row():
txt_cfg_scale = gr.Slider(label="CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.1, interactive=True)
motion_bucket_id = gr.Slider(label="Motion Intensity", value=10, minimum=1, maximum=20, step=1, interactive=True)
with gr.Row():
use_dctinit = gr.Checkbox(label="Enable DCTInit", value=True)
dct_coefficients = gr.Slider(label="DCT Coefficients", value=0.23, minimum=0, maximum=1, step=0.01, interactive=True)
noise_level = gr.Slider(label="Noise Level", value=985, minimum=1, maximum=999, step=1, interactive=True)
input_image.upload(fn=update_textbox_and_save_image, inputs=[input_image, height, width], outputs=[input_image_path, input_image])
preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height, width], outputs=[input_image])
EXAMPLES = [
["./example/aircrafts_flying/0.jpg", "aircrafts flying" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
["./example/fireworks/0.jpg", "fireworks" , "low quality", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
["./example/flowers_swaying/0.jpg", "flowers swaying" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
["./example/girl_walking_on_the_beach/0.jpg", "girl walking on the beach" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.25, 995, 10, 49494220],
["./example/house_rotating/0.jpg", "house rotating" , "low quality", 50, 320, 512, 7.5, True, 0.23, 985, 10, 46640174],
["./example/people_runing/0.jpg", "people runing" , "low quality, background changing", 50, 320, 512, 7.5, True, 0.23, 975, 10, 100],
["./example/shark_swimming/0.jpg", "shark swimming" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 32947978],
["./example/car_moving/0.jpg", "car moving" , "", 50, 320, 512, 7.5, True, 0.23, 975, 10, 75469653],
["./example/windmill_turning/0.jpg", "windmill turning" , "background changing", 50, 320, 512, 7.5, True, 0.21, 975, 10, 89378613],
]
examples = gr.Examples(
examples = EXAMPLES,
fn = gen_video,
inputs=[input_image, prompt_textbox, negative_prompt_textbox, sample_step_slider, height, width, txt_cfg_scale, use_dctinit, dct_coefficients, noise_level, motion_bucket_id, seed_textbox],
outputs=[result_video],
# cache_examples=True,
cache_examples="lazy",
)
generate_button.click(
fn=gen_video,
inputs=[
input_image,
prompt_textbox,
negative_prompt_textbox,
sample_step_slider,
height,
width,
txt_cfg_scale,
use_dctinit,
dct_coefficients,
noise_level,
motion_bucket_id,
seed_textbox,
],
outputs=[result_video]
)
demo.launch(debug=False, share=True, server_name="0.0.0.0")
\ No newline at end of file
# ckpt
ckpt: # not used
save_img_path: "./sample_videos/"
pretrained_model_path: "maxin-cn/Cinemo"
# model config:
model: UNet
video_length: 15
image_size: [320, 512]
# beta schedule
beta_start: 0.0001
beta_end: 0.02
beta_schedule: "linear"
# model speedup
use_compile: False
use_fp16: True
# sample config:
seed:
run_time: 0
use_dct: True
guidance_scale: 7.5 #
motion_bucket_id: 95 # [0-19] The larger the value, the smaller the motion intensity
sample_method: 'DDIM'
num_sampling_steps: 50
enable_vae_temporal_decoder: True
image_prompts: [
['aircraft.jpg', 'aircrafts flying'],
['car.jpg' ,"car moving"],
['fireworks.jpg', 'fireworks'],
['flowers.jpg', 'flowers swaying'],
['forest.jpg', 'people walking'],
['shark_unwater.jpg', 'shark falling into the sea'],
]
import torch
import random
import numbers
from torchvision.transforms import RandomCrop, RandomResizedCrop
from PIL import Image
from torchvision.utils import _log_api_usage_once
def _is_tensor_video_clip(clip):
if not torch.is_tensor(clip):
raise TypeError("clip should be Tensor. Got %s" % type(clip))
if not clip.ndimension() == 4:
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
return True
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
def crop(clip, i, j, h, w):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
"""
if len(clip.size()) != 4:
raise ValueError("clip should be a 4D tensor")
return clip[..., i : i + h, j : j + w]
def resize(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
def resize_scale(clip, target_size, interpolation_mode):
if len(target_size) != 2:
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size[0] / min(H, W)
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resize_with_scale_factor(clip, scale_factor, interpolation_mode):
return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False)
def resize_scale_with_height(clip, target_size, interpolation_mode):
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size / H
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resize_scale_with_weight(clip, target_size, interpolation_mode):
H, W = clip.size(-2), clip.size(-1)
scale_ = target_size / W
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
"""
Do spatial cropping and resizing to the video clip
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the cropped region.
w (int): Width of the cropped region.
size (tuple(int, int)): height and width of resized clip
Returns:
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
clip = crop(clip, i, j, h, w)
clip = resize(clip, size, interpolation_mode)
return clip
def center_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
# print(clip.shape)
th, tw = crop_size
if h < th or w < tw:
# print(h, w)
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
i = int(round((h - th) / 2.0))
j = int(round((w - tw) / 2.0))
return crop(clip, i, j, th, tw), i, j
def center_crop_using_short_edge(clip):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h < w:
th, tw = h, h
i = 0
j = int(round((w - tw) / 2.0))
else:
th, tw = w, w
i = int(round((h - th) / 2.0))
j = 0
return crop(clip, i, j, th, tw)
def random_shift_crop(clip):
'''
Slide along the long edge, with the short edge as crop size
'''
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
if h <= w:
long_edge = w
short_edge = h
else:
long_edge = h
short_edge =w
th, tw = short_edge, short_edge
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return crop(clip, i, j, th, tw), i, j
def random_crop(clip, crop_size):
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
h, w = clip.size(-2), clip.size(-1)
th, tw = crop_size[-2], crop_size[-1]
if h < th or w < tw:
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w))
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
clip_crop = crop(clip, i, j, th, tw)
return clip_crop, i, j
def to_tensor(clip):
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
_is_tensor_video_clip(clip)
if not clip.dtype == torch.uint8:
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
# return clip.float().permute(3, 0, 1, 2) / 255.0
return clip.float() / 255.0
def normalize(clip, mean, std, inplace=False):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
mean (tuple): pixel RGB mean. Size is (3)
std (tuple): pixel standard deviation. Size is (3)
Returns:
normalized clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
if not inplace:
clip = clip.clone()
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
# print(mean)
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
return clip
def hflip(clip):
"""
Args:
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
Returns:
flipped clip (torch.tensor): Size is (T, C, H, W)
"""
if not _is_tensor_video_clip(clip):
raise ValueError("clip should be a 4D torch.tensor")
return clip.flip(-1)
class RandomCropVideo:
def __init__(self, size):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: randomly cropped video clip.
size is (T, C, OH, OW)
"""
i, j, h, w = self.get_params(clip)
return crop(clip, i, j, h, w)
def get_params(self, clip):
h, w = clip.shape[-2:]
th, tw = self.size
if h < th or w < tw:
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
if w == tw and h == th:
return 0, 0, h, w
i = torch.randint(0, h - th + 1, size=(1,)).item()
j = torch.randint(0, w - tw + 1, size=(1,)).item()
return i, j, th, tw
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
class CenterCropResizeVideo:
'''
First use the short side for cropping length,
center crop video, then resize to the specified size
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# print(clip.shape)
clip_center_crop = center_crop_using_short_edge(clip)
# print(clip_center_crop.shape) 320 512
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
return clip_center_crop_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class SDXL:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# add aditional one pixel for avoiding error in center crop
ori_h, ori_w = clip.size(-2), clip.size(-1)
tar_h, tar_w = self.size[0] + 1, self.size[1] + 1
# if ori_h >= tar_h and ori_w >= tar_w:
# clip_tar_crop, i, j = random_crop(clip=clip, crop_size=self.size)
# else:
# tar_h_div_ori_h = tar_h / ori_h
# tar_w_div_ori_w = tar_w / ori_w
# if tar_h_div_ori_h > tar_w_div_ori_w:
# clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
# else:
# clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
# clip_tar_crop, i, j = random_crop(clip, self.size)
if ori_h >= tar_h and ori_w >= tar_w:
tar_h_div_ori_h = tar_h / ori_h
tar_w_div_ori_w = tar_w / ori_w
if tar_h_div_ori_h > tar_w_div_ori_w:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
else:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
ori_h, ori_w = clip.size(-2), clip.size(-1)
clip_tar_crop, i, j = random_crop(clip, self.size)
else:
tar_h_div_ori_h = tar_h / ori_h
tar_w_div_ori_w = tar_w / ori_w
if tar_h_div_ori_h > tar_w_div_ori_w:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
else:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
clip_tar_crop, i, j = random_crop(clip, self.size)
return clip_tar_crop, ori_h, ori_w, i, j
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class SDXLCenterCrop:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# add aditional one pixel for avoiding error in center crop
ori_h, ori_w = clip.size(-2), clip.size(-1)
tar_h, tar_w = self.size[0] + 1, self.size[1] + 1
tar_h_div_ori_h = tar_h / ori_h
tar_w_div_ori_w = tar_w / ori_w
# print('before resize', clip.shape)
if tar_h_div_ori_h > tar_w_div_ori_w:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode)
# print('after h resize', clip.shape)
else:
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode)
# print('after resize', clip.shape)
# print(clip.shape)
# clip_tar_crop, i, j = random_crop(clip, self.size)
clip_tar_crop, i, j = center_crop(clip, self.size)
# print('after crop', clip_tar_crop.shape)
return clip_tar_crop, ori_h, ori_w, i, j
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class InternVideo320512:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
# add aditional one pixel for avoiding error in center crop
h, w = clip.size(-2), clip.size(-1)
# print('before resize', clip.shape)
if h < 320:
clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode)
# print('after h resize', clip.shape)
if w < 512:
clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode)
# print('after w resize', clip.shape)
# print(clip.shape)
clip_center_crop = center_crop(clip, self.size)
clip_center_crop_no_subtitles = center_crop(clip, (220, 352))
clip_center_resize = resize(clip_center_crop_no_subtitles, target_size=self.size, interpolation_mode=self.interpolation_mode)
# print(clip_center_crop.shape)
return clip_center_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class CenterCropVideo:
'''
First scale to the specified size in equal proportion to the short edge,
then center cropping
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
clip_center_crop = center_crop(clip_resize, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class KineticsRandomCropResizeVideo:
'''
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
clip_random_crop = random_shift_crop(clip)
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
return clip_resize
class ResizeVideo():
'''
First use the short side for cropping length,
center crop video, then resize to the specified size
'''
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: scale resized / center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
return clip_resize
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class CenterCropVideo:
def __init__(
self,
size,
interpolation_mode="bilinear",
):
if isinstance(size, tuple):
if len(size) != 2:
raise ValueError(f"size should be tuple (height, width), instead got {size}")
self.size = size
else:
self.size = (size, size)
self.interpolation_mode = interpolation_mode
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
Returns:
torch.tensor: center cropped video clip.
size is (T, C, crop_size, crop_size)
"""
clip_center_crop = center_crop(clip, self.size)
return clip_center_crop
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
class NormalizeVideo:
"""
Normalize the video clip by mean subtraction and division by standard deviation
Args:
mean (3-tuple): pixel RGB mean
std (3-tuple): pixel RGB standard deviation
inplace (boolean): whether do in-place normalization
"""
def __init__(self, mean, std, inplace=False):
self.mean = mean
self.std = std
self.inplace = inplace
def __call__(self, clip):
"""
Args:
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
"""
return normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo:
"""
Convert tensor data type from uint8 to float, divide value by 255.0 and
permute the dimensions of clip tensor
"""
def __init__(self):
pass
def __call__(self, clip):
"""
Args:
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
Return:
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
"""
return to_tensor(clip)
def __repr__(self) -> str:
return self.__class__.__name__
class RandomHorizontalFlipVideo:
"""
Flip the video clip along the horizontal direction with a given probability
Args:
p (float): probability of the clip being flipped. Default value is 0.5
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, clip):
"""
Args:
clip (torch.tensor): Size is (T, C, H, W)
Return:
clip (torch.tensor): Size is (T, C, H, W)
"""
if random.random() < self.p:
clip = hflip(clip)
return clip
def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"
class Compose:
"""Composes several transforms together. This transform does not support torchscript.
Please, see the note below.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
.. note::
In order to script the transformations, please use ``torch.nn.Sequential`` as below.
>>> transforms = torch.nn.Sequential(
>>> transforms.CenterCrop(10),
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
>>> )
>>> scripted_transforms = torch.jit.script(transforms)
Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
`lambda` functions or ``PIL.Image``.
"""
def __init__(self, transforms):
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(self)
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
if isinstance(t, SDXLCenterCrop) or isinstance(t, SDXL):
img, ori_h, ori_w, crops_coords_top, crops_coords_left = t(img)
else:
img = t(img)
return img, ori_h, ori_w, crops_coords_top, crops_coords_left
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
for t in self.transforms:
format_string += "\n"
format_string += f" {t}"
format_string += "\n)"
return format_string
# ------------------------------------------------------------
# --------------------- Sampling ---------------------------
# ------------------------------------------------------------
class TemporalRandomCrop(object):
"""Temporally crop the given frame indices at a random location.
Args:
size (int): Desired length of frames will be seen in the model.
"""
def __init__(self, size):
self.size = size
def __call__(self, total_frames):
rand_end = max(0, total_frames - self.size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + self.size, total_frames)
return begin_index, end_index
if __name__ == '__main__':
from torchvision import transforms
import torchvision.io as io
import numpy as np
from torchvision.utils import save_image
import os
vframes, aframes, info = io.read_video(
filename='./v_Archery_g01_c03.avi',
pts_unit='sec',
output_format='TCHW'
)
trans = transforms.Compose([
ToTensorVideo(),
RandomHorizontalFlipVideo(),
UCFCenterCropVideo(512),
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
target_video_len = 32
frame_interval = 1
total_frames = len(vframes)
print(total_frames)
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
# Sampling video frames
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
# print(start_frame_ind)
# print(end_frame_ind)
assert end_frame_ind - start_frame_ind >= target_video_len
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
print(frame_indice)
select_vframes = vframes[frame_indice]
print(select_vframes.shape)
print(select_vframes.dtype)
select_vframes_trans = trans(select_vframes)
print(select_vframes_trans.shape)
print(select_vframes_trans.dtype)
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
print(select_vframes_trans_int.dtype)
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
for i in range(target_video_len):
save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
\ No newline at end of file
name: cinemo
channels:
- pytorch
- nvidia
dependencies:
- python >= 3.10
- pytorch >= 2.0
- torchvision
- pytorch-cuda >= 11.7
- pip:
- timm
- diffusers[torch]==0.24.0
- accelerate
- python-hostlist
- tensorboard
- einops
- transformers
- av
- scikit-image
- decord
- pandas
- imageio-ffmpeg
- torch_dct
- omegaconf
import os
import sys
sys.path.append(os.path.split(sys.path[0])[0])
from .unet import UNet3DConditionModel
from torch.optim.lr_scheduler import LambdaLR
def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
from torch.optim.lr_scheduler import LambdaLR
def fn(step):
if warmup_steps > 0:
return min(step / warmup_steps, 1)
else:
return 1
return LambdaLR(optimizer, fn)
def get_lr_scheduler(optimizer, name, **kwargs):
if name == 'warmup':
return customized_lr_scheduler(optimizer, **kwargs)
elif name == 'cosine':
from torch.optim.lr_scheduler import CosineAnnealingLR
return CosineAnnealingLR(optimizer, **kwargs)
else:
raise NotImplementedError(name)
def get_models(args):
if 'UNet' in args.model:
return UNet3DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet")
else:
raise '{} Model Not Supported!'.format(args.model)
\ No newline at end of file
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
try:
from diffusers.utils import maybe_allow_in_graph
except:
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
from diffusers.models.lora import LoRACompatibleLinear
from einops import rearrange, repeat
try:
from temporal_attention import TemporalAttention, CrossAttention, PseudoCrossAttention
except:
from .temporal_attention import TemporalAttention, CrossAttention, PseudoCrossAttention
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
rotary_emb=None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Temporal-Attn
self.attn_temp = TemporalAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
upcast_attention=upcast_attention,
rotary_emb=rotary_emb,
)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
# Temporal text cross attention
# self.attn_temp_text = CrossAttention(query_dim=dim,
# cross_attention_dim=cross_attention_dim,
# heads=num_attention_heads,
# dim_head=attention_head_dim,
# dropout=dropout,
# bias=attention_bias,
# upcast_attention=upcast_attention,
# )
# self.norm_temp_text = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# nn.init.zeros_(self.attn_temp_text.to_out[0].weight.data)
# 5. Feed-forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
video_length=None,
use_image_num=None,
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# Temporal Attention
if self.training and use_image_num != 0:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
hidden_states_video = hidden_states[:, :video_length, :]
hidden_states_image = hidden_states[:, video_length:, :]
# with torch.cuda.amp.autocast(dtype=torch.float32):
norm_hidden_states_video = (
self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
)
hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
# # # Temporal Text Cross Attention
# encoder_hidden_states_reshape = rearrange(encoder_hidden_states, '(b f) d c -> b f d c', f=video_length + use_image_num).contiguous()
# encoder_hidden_states_video = encoder_hidden_states_reshape[:, 0, ...].contiguous()
# encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b d c -> (b t) d c', t=d).contiguous()
# norm_hidden_states_video = (
# self.norm_temp_text(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp_text(hidden_states_video)
# )
# hidden_states_video = self.attn_temp_text(norm_hidden_states_video, encoder_hidden_states=encoder_hidden_states_video) + hidden_states_video
# ################## end Temporal Text Cross Attention ###################
hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
# hidden_states = torch.cat([hidden_states_video.to(hidden_states_image.dtype), hidden_states_image], dim=1)
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
else:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
# # # Temporal Text Cross Attention
# encoder_hidden_states_reshape = rearrange(encoder_hidden_states, '(b f) d c -> b f d c', f=video_length + use_image_num).contiguous()
# encoder_hidden_states_video = encoder_hidden_states_reshape[:, 0, ...].contiguous()
# encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b d c -> (b t) d c', t=d).contiguous()
# norm_hidden_states = (
# self.norm_temp_text(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp_text(hidden_states)
# )
# hidden_states = self.attn_temp_text(norm_hidden_states, encoder_hidden_states=encoder_hidden_states_video) + hidden_states
# ################# end Temporal Text Cross Attention ###################
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh")
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class GELU(nn.Module):
r"""
GELU activation function with tanh approximation support with `approximate="tanh"`.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def forward(self, x):
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
class AdaLayerNorm(nn.Module):
"""
Norm layer modified to incorporate timestep embeddings.
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
def forward(self, x, timestep):
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
class AdaLayerNormZero(nn.Module):
"""
Norm layer adaptive layer norm zero (adaLN-Zero).
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, timestep, class_labels, hidden_dtype=None):
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaGroupNorm(nn.Module):
"""
GroupNorm layer modified to incorporate timestep embeddings.
"""
def __init__(
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
):
super().__init__()
self.num_groups = num_groups
self.eps = eps
if act_fn is None:
self.act = None
else:
self.act = get_activation(act_fn)
self.linear = nn.Linear(embedding_dim, out_dim * 2)
def forward(self, x, emb):
if self.act:
emb = self.act(emb)
emb = self.linear(emb)
emb = emb[:, :, None, None]
scale, shift = emb.chunk(2, dim=1)
x = F.group_norm(x, self.num_groups, eps=self.eps)
x = x * (1 + scale) + shift
return x
# Copyright 2023 The HuggingFace Team. All rights reserved.
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.activations import get_activation
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
"""A 2D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
# conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(hidden_states)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
"""A 2D downsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
padding (`int`, default `1`):
padding for the convolution.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
# conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
skip_time_act=False,
time_embedding_norm="default", # default, scale_shift, ada_group, spatial
kernel=None,
output_scale_factor=1.0,
use_in_shortcut=None,
up=False,
down=False,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
if groups_out is None:
groups_out = groups
if self.time_embedding_norm == "ada_group":
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm1 = SpatialNorm(in_channels, temb_channels)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
# self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
else:
self.time_emb_proj = None
if self.time_embedding_norm == "ada_group":
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
elif self.time_embedding_norm == "spatial":
self.norm2 = SpatialNorm(out_channels, temb_channels)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
# self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = InflatedConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample3D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample3D(in_channels, use_conv=False, padding=1, name="op")
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
# self.conv_shortcut = LoRACompatibleConv(
# in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
# )
self.conv_shortcut = InflatedConv3d(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm1(hidden_states, temb)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
elif self.downsample is not None:
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
hidden_states = self.conv1(hidden_states)
# print(self.time_emb_proj) # LoRACompatibleLinear(in_features=1280, out_features=320, bias=True)
# print(self.nonlinearity) # SiLU()
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)
# temb = temb[:, :, None, None, None]
# if self.training:
# temb = rearrange(temb, 'b f d -> b d f')[..., None, None]
# else:
# temb = temb[:, :, None, None, None]
temb = temb[:, :, None, None, None]
# print(temb.shape)
if temb is not None and self.time_embedding_norm == "default":
# print(hidden_states.shape)
hidden_states = hidden_states + temb
# torch.Size([2, 320, 21, 32, 32])
# torch.Size([2, 320, 1, 1, 1])
# torch.Size([2, 320, 21, 32, 32])
# torch.Size([2, 320, 1, 1, 1])
# torch.Size([2, 640, 21, 16, 16])
# torch.Size([2, 640, 1, 1, 1])
# torch.Size([2, 640, 21, 16, 16])
# torch.Size([2, 640, 1, 1, 1])
# torch.Size([2, 1280, 21, 8, 8])
# torch.Size([2, 1280, 1, 1, 1])
# torch.Size([2, 1280, 21, 8, 8])
# torch.Size([2, 1280, 1, 1, 1])
# torch.Size([2, 1280, 21, 4, 4])
# torch.Size([2, 1280, 1, 1, 1])
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
hidden_states = self.norm2(hidden_states, temb)
else:
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
\ No newline at end of file
from math import pi, log
import torch
from torch import nn, einsum
from einops import rearrange, repeat
# helper functions
def exists(val):
return val is not None
def broadcat(tensors, dim = -1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim = dim)
# rotary embedding helper functions
def rotate_half(x):
x = rearrange(x, '... (d r) -> ... d r', r = 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d r -> ... (d r)')
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1.):
freqs = freqs.to(t)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
return torch.cat((t_left, t, t_right), dim = -1)
# learned rotation helpers
def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
if exists(freq_ranges):
rotations = einsum('..., f -> ... f', rotations, freq_ranges)
rotations = rearrange(rotations, '... r f -> ... (r f)')
rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
return apply_rotary_emb(rotations, t, start_index = start_index)
# classes
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
custom_freqs = None,
freqs_for = 'lang',
theta = 10000,
max_freq = 10,
num_freqs = 1,
learned_freq = False,
use_xpos = False,
xpos_scale_base = 512,
interpolate_factor = 1.,
theta_rescale_factor = 1.
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
theta *= theta_rescale_factor ** (dim / (dim - 2))
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == 'lang':
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
elif freqs_for == 'constant':
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f'unknown modality {freqs_for}')
self.cache = dict()
self.cache_scale = dict()
# self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
self.register_buffer('freqs', freqs)
# interpolation factors
assert interpolate_factor >= 1.
self.interpolate_factor = interpolate_factor
# xpos
self.use_xpos = use_xpos
if not use_xpos:
self.register_buffer('scale', None)
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.scale_base = xpos_scale_base
self.register_buffer('scale', scale)
def get_seq_pos(self, seq_len, device, dtype, offset = 0):
return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
def rotate_queries_or_keys(self, t, seq_dim = -2, offset = 0):
assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
freqs = self.forward(lambda: self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), cache_key = f'freqs:{seq_len}|offset:{offset}')
return apply_rotary_emb(freqs, t)
def rotate_queries_and_keys(self, q, k, seq_dim = -2):
assert self.use_xpos
device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}')
scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}').to(dtype)
rotated_q = apply_rotary_emb(freqs, q, scale = scale)
rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1)
return rotated_q, rotated_k
def get_scale(self, t, cache_key = None):
assert self.use_xpos
if exists(cache_key) and cache_key in self.cache:
return self.cache[cache_key]
if callable(t):
t = t()
scale = 1.
if self.use_xpos:
power = (t - len(t) // 2) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
if exists(cache_key):
self.cache[cache_key] = scale
return scale
def forward(self, t, cache_key = None):
if exists(cache_key) and cache_key in self.cache:
return self.cache[cache_key]
if callable(t):
t = t()
freqs = self.freqs
freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
if exists(cache_key):
self.cache[cache_key] = freqs
return freqs
import torch
from torch import nn
from typing import Optional
from dataclasses import dataclass
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
import torch.nn.functional as F
from einops import rearrange, repeat
import math
@dataclass
class Transformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def exists(x):
return x is not None
class CrossAttention(nn.Module):
r"""
copy from diffuser 0.11.1
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
use_relative_position: bool = False,
):
super().__init__()
# print('num head', heads)
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False # No use xformers for temporal attention
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def reshape_for_scores(self, tensor):
# split heads and dims
# tensor should be [b (h w)] f (d nd)
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).contiguous()
return tensor
def same_batch_dim_to_heads(self, tensor):
batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
return tensor
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
self._slice_size = slice_size
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
# print('before reshpape query shape', query.shape)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
# print('after reshape query shape', query.shape)
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# do not use xformers for temporal attention
# # attention, what we cannot get enough of
# if self._use_memory_efficient_attention_xformers:
# hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# # Some versions of xformers return output in fp32, cast it back to the dtype of the input
# hidden_states = hidden_states.to(query.dtype)
# else:
# if self._slice_size is None or query.shape[0] // self._slice_size == 1:
# hidden_states = self._attention(query, key, value, attention_mask)
# else:
# hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
hidden_states = self._attention(query, key, value, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
# print('query shape', query.shape)
# print('key shape', key.shape)
# print('value shape', value.shape)
if attention_mask is not None:
# print('attention_mask', attention_mask.shape)
# print('attention_scores', attention_scores.shape)
# exit()
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
# print(attention_probs.shape)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# print(attention_probs.shape)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
# print(hidden_states.shape)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
# print(hidden_states.shape)
# exit()
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
if self.upcast_softmax:
attn_slice = attn_slice.float()
attn_slice = attn_slice.softmax(dim=-1)
# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# print(query.shape)
# print(key.shape)
# print(value.shape)
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
# print(hidden_states.shape)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
# print(hidden_states.shape)
# exit()
return hidden_states
class TemporalAttention(CrossAttention):
def __init__(self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
rotary_emb=None):
super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
# relative time positional embeddings
self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
self.rotary_emb = rotary_emb
# self.rotary_emb = RotaryEmbedding(32)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
dim = query.shape[-1]
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# Do not use xformers for temporal attention
# attention, what we cannot get enough of
# if self._use_memory_efficient_attention_xformers:
# hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# # Some versions of xformers return output in fp32, cast it back to the dtype of the input
# hidden_states = hidden_states.to(query.dtype)
# else:
# if self._slice_size is None or query.shape[0] // self._slice_size == 1:
# hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
# else:
# hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
if self.upcast_attention:
query = query.float()
key = key.float()
# print('query shape', query.shape)
# print('key shape', key.shape)
# print('value shape', value.shape)
# reshape for adding time positional bais
query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
# print('query shape', query.shape)
# print('key shape', key.shape)
# print('value shape', value.shape)
# torch.baddbmm only accepte 3-D tensor
# https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
# attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
if exists(self.rotary_emb):
query = self.rotary_emb.rotate_queries_or_keys(query)
key = self.rotary_emb.rotate_queries_or_keys(key)
attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
# print('attention_scores shape', attention_scores.shape)
# print('time_rel_pos_bias shape', time_rel_pos_bias.shape)
# print('attention_mask shape', attention_mask.shape)
attention_scores = attention_scores + time_rel_pos_bias
# print(attention_scores.shape)
# bert from huggin face
# attention_scores = attention_scores / math.sqrt(self.dim_head)
# # Normalize the attention scores to probabilities.
# attention_probs = nn.functional.softmax(attention_scores, dim=-1)
if attention_mask is not None:
# add attention mask
attention_scores = attention_scores + attention_mask
# vdm
attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
# # Mask out future positions (causal mask)
# mask = torch.triu(torch.ones(16, 16), diagonal=1).to(device=attention_scores.device, dtype=attention_scores.dtype) #
# attention_scores.masked_fill_(mask == 1, float('-inf'))
# # # disable the fisrt frame
# mask = torch.zeros(16, 16).to(device=attention_scores.device, dtype=attention_scores.dtype)
# mask[:, :1] = 1
# mask[0, 0] = 0
# attention_scores.masked_fill_(mask == 1, float('-inf'))
# only enable the first frame to internact with others frames
# mask = torch.zeros(16, 16).to(device=attention_scores.device, dtype=attention_scores.dtype)
# mask[:1, 1:] = 1
# attention_scores.masked_fill_(mask == 1, float('-inf'))
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# print(attention_probs[0][0])
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
# hidden_states = torch.matmul(attention_probs, value)
hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
# print(hidden_states.shape)
# hidden_states = self.same_batch_dim_to_heads(hidden_states)
hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
# print(hidden_states.shape)
# exit()
return hidden_states
class RelativePositionBias(nn.Module):
def __init__(
self,
heads=8,
num_buckets=32,
max_distance=128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
ret = 0
n = -relative_position
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def forward(self, n, device):
q_pos = torch.arange(n, dtype = torch.long, device = device)
k_pos = torch.arange(n, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
class PseudoCrossAttention(CrossAttention):
def forward(self, hidden_states, encoder_hidden_states=None, base_content=None, attention_mask=None, video_length=None):
batch_size, sequence_length, _ = hidden_states.shape
video_length = 17
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = rearrange(key, "(b f) d c -> b f d c", f=video_length).contiguous()
key[:, 1:] = key[:, 1:] + key[:, :1]
key = rearrange(key, "b f d c -> (b f) d c").contiguous()
value = rearrange(value, "(b f) d c -> b f d c", f=video_length).contiguous()
value[:, 1:] = value[:, 1:] + value[:, :1]
value = rearrange(value, "b f d c -> (b f) d c").contiguous()
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# hidden_states = rearrange(hidden_states, '(b f) d c -> b f d c', f=video_length).contiguous()
# hidden_states[:, :1, ...] = base_content
# hidden_states = rearrange(hidden_states, 'b f d c -> (b f) d c')
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
\ No newline at end of file
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import ImagePositionalEmbeddings
from diffusers.utils import BaseOutput, deprecate
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange, repeat
try:
from attention import BasicTransformerBlock
except:
from .attention import BasicTransformerBlock
@dataclass
class Transformer3DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer3DModel(ModelMixin, ConfigMixin):
"""
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
rotary_emb=None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
else:
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
rotary_emb=rotary_emb,
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
else:
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches:
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
use_image_num=None,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
if self.is_input_continuous: # True
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
if self.training and use_image_num != 0:
video_length = hidden_states.shape[2] - use_image_num
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
encoder_hidden_states_length = encoder_hidden_states.shape[1]
encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
else:
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
hidden_states = self.pos_embed(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
video_length=video_length,
use_image_num=use_image_num,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
elif self.is_input_patches:
# TODO: cleanup!
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment