Commit 1f5da520 authored by yangzhong

git init
xformers is required, so the image used is image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10.
The dtk25.04.1 and dtk25.04.2 images do not ship an adapted xformers build.
```
# Pull the image
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10
# Create the container
docker run -it --network=host --name=dtk24043_torch23 -v /opt/hyhal:/opt/hyhal:ro -v /usr/local/hyhal:/usr/local/hyhal:ro -v /public:/public:ro --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=128G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.3-py3.10
```
```
git clone https://github.com/NJU-PCALab/STAR.git
cd STAR
pip install -r requirements.txt # install the dependencies missing from the environment, comment out those already installed; open-clip-torch must be installed at the specified version!
# Install diffusers
git clone -b v0.30.0-release http://developer.sourcefind.cn/codes/OpenDAS/diffusers.git
cd diffusers/
python3 setup.py install
sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y
```
#### Step 1: Download the pretrained models from [HuggingFace](https://huggingface.co/SherryX/STAR).
We provide two weights for the I2VGen-XL-based model: `heavy_deg.pt` for heavily degraded videos and `light_deg.pt` for lightly degraded videos (e.g., low-resolution videos downloaded from video websites).
Put the weights into `pretrained_weight/`.
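One way to fetch the weights is with git-lfs. This is a minimal sketch; the exact location of the `.pt` files inside the HuggingFace repository is an assumption, so adjust the copy step if the layout differs.
```
# clone the weight repository and collect the I2VGen-XL checkpoints
git lfs install
git clone https://huggingface.co/SherryX/STAR
mkdir -p pretrained_weight
find STAR -name "*_deg.pt" -exec cp {} pretrained_weight/ \;
```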
#### Step 2: Prepare testing data (already included in this PR, so this step can be skipped)
You can put the testing videos in `input/video/`.
As for the prompt, there are three options: 1. No prompt. 2. Automatically generate a prompt (e.g., [using Pllava](https://github.com/hpcaitech/Open-Sora/tree/main/tools/caption#pllava-captioning)). 3. Manually write the prompt. Put the txt file in `input/text/`.
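For reference, a minimal input layout could look like the following sketch (file names are placeholders; pairing the prompt file with the video by stem is an assumption, so check `inference_sr.sh` for the exact convention):
```
mkdir -p input/video input/text
cp /path/to/my_clip.mp4 input/video/my_clip.mp4
# optional text prompt describing the video content
echo "A short, detailed description of the video." > input/text/my_clip.txt
```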
#### Step 3: Change the paths to your local ones
You need to change the paths in `video_super_resolution/scripts/inference_sr.sh` to your local corresponding paths, including `video_folder_path`, `txt_file_path`, `model_path`, and `save_dir`.
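The variables to edit look roughly like this; the values below are illustrative placeholders, not the script's actual defaults:
```
# in video_super_resolution/scripts/inference_sr.sh
video_folder_path='./input/video'
txt_file_path='./input/text/my_clip.txt'
model_path='./pretrained_weight/light_deg.pt'
save_dir='./results'
```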
#### Step 4: Run the inference command
```
bash video_super_resolution/scripts/inference_sr.sh
```
### CogVideoX-based Model Inference
#### Step 1: Install the requirements
```
conda create -n star_cog python=3.10
conda activate star_cog
cd cogvideox-based/sat
pip install -r requirements.txt
```
#### Step 2: Download the pretrained model.
Download STAR from [HuggingFace](https://huggingface.co/SherryX/STAR).
Download VAE and T5 Encoder following this [instruction](https://github.com/THUDM/CogVideo/blob/main/sat/README_zh.md#cogvideox15-%E6%A8%A1%E5%9E%8B).
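For orientation, the inference config included later in this commit expects the downloaded weights in a layout along these lines; the root directory is whatever you point the config at in Step 4, and the sub-paths mirror `load`, `model_dir`, and `ckpt_path` in the YAML:
```
mkdir -p pretrained_models/cogvideox/transformer pretrained_models/cogvideox/t5-v1_1-xxl pretrained_models/cogvideox/vae
# transformer/   -> STAR (CogVideoX-based) weights, referenced by `load`
# t5-v1_1-xxl/   -> T5 encoder, referenced by `model_dir` of conditioner_config
# vae/3d-vae.pt  -> 3D VAE checkpoint, referenced by `ckpt_path` of first_stage_config
```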
#### Step 3: Prepare testing data
You can put the testing videos in the `input/video/`.
As for the prompt, there are three options: 1. No prompt. 2. Automatically generate a prompt [using Pllava](https://github.com/hpcaitech/Open-Sora/tree/main/tools/caption#pllava-captioning). 3. Manually write the prompt. You can put the txt file in the `input/text/`.
#### Step 4: Change the configs
You need to update the paths in `cogvideox-based/sat/configs/cogvideox_5b/cogvideox_5b_infer_sr.yaml` to match your local setup, including `load`, `output_dir`, `model_dir` of `conditioner_config`, and `ckpt_path` of `first_stage_config`. Additionally, update the `test_dataset` path in `sample_sr.py`.
#### Step 5: Replace transformer.py in the sat package
Replace the `/cogvideo/lib/python3.9/site-packages/sat/model/transformer.py` in your environment with our provided [transformer.py](https://github.com/NJU-PCALab/STAR/blob/main/cogvideox-based/transformer.py).
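One way to locate the installed sat package and swap the file in (a sketch; the exact site-packages path depends on your environment):
```
SAT_DIR=$(python -c "import sat, os; print(os.path.dirname(sat.__file__))")
cp cogvideox-based/transformer.py "${SAT_DIR}/model/transformer.py"
```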
#### Step 6: Run the inference command
```
bash inference_sr.sh
```
"""
This script demonstrates how to generate a video from a text prompt using CogVideoX with 🤗Huggingface Diffusers Pipeline.
Note:
This script requires the `diffusers>=0.30.0` library to be installed.
If the video exported using OpenCV appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
Run the script:
$ python cli_demo.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-2b
"""
import argparse
import tempfile
from typing import Union, List
import PIL
import imageio
import numpy as np
import torch
from diffusers import CogVideoXPipeline
def export_to_video_imageio(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
"""
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
"""
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
with imageio.get_writer(output_video_path, fps=fps) as writer:
for frame in video_frames:
writer.append_data(frame)
return output_video_path
def generate_video(
prompt: str,
model_path: str,
output_path: str = "./output.mp4",
num_inference_steps: int = 50,
guidance_scale: float = 6.0,
num_videos_per_prompt: int = 1,
device: str = "cuda",
dtype: torch.dtype = torch.float16,
):
"""
Generates a video based on the given prompt and saves it to the specified path.
Parameters:
- prompt (str): The description of the video to be generated.
- model_path (str): The path of the pre-trained model to be used.
- output_path (str): The path where the generated video will be saved.
- num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
- guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
- dtype (torch.dtype): The data type for computation (default is torch.float16).
"""
# Load the pre-trained CogVideoX pipeline with the specified precision (float16) and move it to the specified device
    # To enable multi-GPU inference (2 or more GPUs, each with more than 20GB of memory),
    # add device_map="balanced" to the from_pretrained call and remove `pipe.enable_model_cpu_offload()`.
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
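    # For reference, the multi-GPU variant mentioned above would look like the commented line below
    # (illustrative only; it assumes 2 or more GPUs with more than 20GB of memory each):
    # pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype, device_map="balanced")
    # In that case, do not call pipe.enable_model_cpu_offload().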
# Encode the prompt to get the prompt embeddings
prompt_embeds, _ = pipe.encode_prompt(
prompt=prompt, # The textual description for video generation
negative_prompt=None, # The negative prompt to guide the video generation
do_classifier_free_guidance=True, # Whether to use classifier-free guidance
num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
max_sequence_length=226, # Maximum length of the sequence, must be 226
device=device, # Device to use for computation
dtype=dtype, # Data type for computation
)
# Generate the video frames using the pipeline
video = pipe(
num_inference_steps=num_inference_steps, # Number of inference steps
guidance_scale=guidance_scale, # Guidance scale for classifier-free guidance
prompt_embeds=prompt_embeds, # Encoded prompt embeddings
        negative_prompt_embeds=torch.zeros_like(prompt_embeds),  # negative prompts are not supported
).frames[0]
# Export the generated frames to a video file. fps must be 8
export_to_video_imageio(video, output_path, fps=8)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
parser.add_argument(
"--model_path", type=str, default="THUDM/CogVideoX-2b", help="The path of the pre-trained model to be used"
)
parser.add_argument(
"--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
)
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
)
parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
parser.add_argument(
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
)
parser.add_argument(
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
)
args = parser.parse_args()
    # Convert the dtype argument to torch.dtype; BF16 is not recommended.
dtype = torch.float16 if args.dtype == "float16" else torch.float32
# main function to generate video.
generate_video(
prompt=args.prompt,
model_path=args.model_path,
output_path=args.output_path,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
num_videos_per_prompt=args.num_videos_per_prompt,
device=args.device,
dtype=dtype,
)
"""
This script demonstrates how to encode video frames using a pre-trained CogVideoX model with 🤗 Huggingface Diffusers.
Note:
This script requires the `diffusers>=0.30.0` library to be installed.
If the video appears “completely green” and cannot be viewed, please switch to a different player to watch it. This is a normal phenomenon.
Encoding a 6-second video at 720p resolution costs about 71GB of GPU memory.
Run the script:
$ python cli_demo.py --model_path THUDM/CogVideoX-2b --video_path path/to/video.mp4 --output_path path/to/output
"""
import argparse
import torch
import imageio
import numpy as np
from diffusers import AutoencoderKLCogVideoX
from torchvision import transforms
def vae_demo(model_path, video_path, dtype, device):
"""
Loads a pre-trained AutoencoderKLCogVideoX model and encodes the video frames.
Parameters:
- model_path (str): The path to the pre-trained model.
- video_path (str): The path to the video file.
- dtype (torch.dtype): The data type for computation.
- device (str): The device to use for computation (e.g., "cuda" or "cpu").
Returns:
- torch.Tensor: The encoded video frames.
"""
# Load the pre-trained model
model = AutoencoderKLCogVideoX.from_pretrained(model_path, torch_dtype=dtype).to(device)
# Load video frames
video_reader = imageio.get_reader(video_path, "ffmpeg")
frames = []
for frame in video_reader:
frames.append(frame)
video_reader.close()
# Transform frames to Tensor
transform = transforms.Compose(
[
transforms.ToTensor(),
]
)
frames_tensor = torch.stack([transform(frame) for frame in frames]).to(device)
# Add batch dimension and reshape to [1, 3, 49, 480, 720]
frames_tensor = frames_tensor.permute(1, 0, 2, 3).unsqueeze(0).to(dtype).to(device)
# Run the model with Encoder and Decoder
with torch.no_grad():
output = model(frames_tensor)
return output
def save_video(tensor, output_path):
"""
Saves the encoded video frames to a video file.
Parameters:
- tensor (torch.Tensor): The encoded video frames.
- output_path (str): The path to save the output video.
"""
# Remove batch dimension and permute back to [49, 480, 720, 3]
frames = tensor[0].squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
# Clip values to [0, 1] and convert to uint8
frames = np.clip(frames, 0, 1)
frames = (frames * 255).astype(np.uint8)
# Save frames to video
writer = imageio.get_writer(output_path + "/output.mp4", fps=30)
for frame in frames:
writer.append_data(frame)
writer.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Encode and decode a video with the CogVideoX VAE")
parser.add_argument("--model_path", type=str, required=True, help="The path to the CogVideoX model")
parser.add_argument("--video_path", type=str, required=True, help="The path to the video file")
parser.add_argument("--output_path", type=str, default="./", help="The path to save the output video")
parser.add_argument(
"--dtype", type=str, default="float16", help="The data type for computation (e.g., 'float16' or 'float32')"
)
parser.add_argument(
"--device", type=str, default="cuda", help="The device to use for computation (e.g., 'cuda' or 'cpu')"
)
args = parser.parse_args()
# Set device and dtype
device = torch.device(args.device)
dtype = torch.float16 if args.dtype == "float16" else torch.float32
output = vae_demo(args.model_path, args.video_path, dtype, device)
save_video(output, args.output_path)
"""
The CogVideoX model is pre-trained and fine-tuned with longer, more detailed prompts, so it requires highly granular and detailed prompts as input. This script transforms user input into prompts suited to CogVideoX, enabling superior video generation.
This step is not mandatory; the model will still function correctly and without errors even if the prompts are not refined with this script. However, we strongly recommend using it to ensure the generation of high-quality videos.
Note:
Please set the OPENAI_API_KEY (and, if needed, OPENAI_BASE_URL) environment variables before running this script.
Run the script:
$ python convert_demo.py --prompt "A girl riding a bike." # Using with OpenAI's API
"""
import argparse
from openai import OpenAI
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
There are a few rules to follow:
You will only ever output a single video description per user request.
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
"""
def convert_prompt(prompt: str, retry_times: int = 3):
"""
Convert a prompt to a format that can be used by the model for inference
"""
client = OpenAI()
text = prompt.strip()
for i in range(retry_times):
response = client.chat.completions.create(
messages=[
{"role": "system", "content": f"{sys_prompt}"},
{
"role": "user",
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " a girl is on the beach"',
},
{
"role": "assistant",
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
},
{
"role": "user",
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A man jogging on a football field"',
},
{
"role": "assistant",
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
},
{
"role": "user",
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
},
{
"role": "assistant",
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
},
{
"role": "user",
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: " {text} "',
},
],
            model="glm-4-0520",  # glm-4-0520 and gpt-4o have been tested
temperature=0.01,
top_p=0.7,
stream=False,
max_tokens=250,
)
if response.choices:
return response.choices[0].message.content
return prompt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--prompt", type=str, required=True, help="Prompt to convert")
parser.add_argument("--retry_times", type=int, default=3, help="Number of times to retry the conversion")
args = parser.parse_args()
converted_prompt = convert_prompt(args.prompt, args.retry_times)
print(converted_prompt)
import os
import tempfile
import threading
import time
import gradio as gr
import numpy as np
import torch
from diffusers import CogVideoXPipeline
from datetime import datetime, timedelta
from openai import OpenAI
import imageio
import moviepy.editor as mp
from typing import List, Union
import PIL
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=dtype)
pipe.enable_model_cpu_offload()
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
There are a few rules to follow:
You will only ever output a single video description per user request.
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
"""
def export_to_video_imageio(
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
) -> str:
"""
Export the video frames to a video file using imageio lib to Avoid "green screen" issue (for example CogVideoX)
"""
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
if isinstance(video_frames[0], PIL.Image.Image):
video_frames = [np.array(frame) for frame in video_frames]
with imageio.get_writer(output_video_path, fps=fps) as writer:
for frame in video_frames:
writer.append_data(frame)
return output_video_path
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
if not os.environ.get("OPENAI_API_KEY"):
return prompt
client = OpenAI()
text = prompt.strip()
for i in range(retry_times):
response = client.chat.completions.create(
messages=[
{"role": "system", "content": sys_prompt},
{"role": "user",
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"'},
{"role": "assistant",
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance."},
{"role": "user",
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"'},
{"role": "assistant",
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field."},
{"role": "user",
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"'},
{"role": "assistant",
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background."},
{"role": "user",
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"'},
],
model="glm-4-0520",
temperature=0.01,
top_p=0.7,
stream=False,
max_tokens=250,
)
if response.choices:
return response.choices[0].message.content
return prompt
def infer(
prompt: str,
num_inference_steps: int,
guidance_scale: float,
progress=gr.Progress(track_tqdm=True)
):
torch.cuda.empty_cache()
prompt_embeds, _ = pipe.encode_prompt(
prompt=prompt,
negative_prompt=None,
do_classifier_free_guidance=True,
num_videos_per_prompt=1,
max_sequence_length=226,
device=device,
dtype=dtype,
)
video = pipe(
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=torch.zeros_like(prompt_embeds),
).frames[0]
return video
def save_video(tensor):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
video_path = f"./output/{timestamp}.mp4"
os.makedirs(os.path.dirname(video_path), exist_ok=True)
export_to_video_imageio(tensor[1:], video_path)
return video_path
def convert_to_gif(video_path):
clip = mp.VideoFileClip(video_path)
clip = clip.set_fps(8)
clip = clip.resize(height=240)
gif_path = video_path.replace('.mp4', '.gif')
clip.write_gif(gif_path, fps=8)
return gif_path
def delete_old_files():
while True:
now = datetime.now()
cutoff = now - timedelta(minutes=10)
output_dir = './output'
for filename in os.listdir(output_dir):
file_path = os.path.join(output_dir, filename)
if os.path.isfile(file_path):
file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
if file_mtime < cutoff:
os.remove(file_path)
time.sleep(600) # Sleep for 10 minutes
threading.Thread(target=delete_old_files, daemon=True).start()
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
CogVideoX-2B Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/THUDM/CogVideoX-2b">🤗 Model Hub</a> |
<a href="https://github.com/THUDM/CogVideo">🌐 Github</a>
</div>
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
⚠️ This demo is for academic research and experiential use only.
Users should strictly adhere to local laws and ethics.
</div>
""")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
with gr.Row():
gr.Markdown(
"✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one.")
enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
with gr.Column():
                gr.Markdown("**Optional Parameters** (default values are recommended)<br>"
                            "Increase Inference Steps for a more detailed video, but generation will be slower.<br>"
                            "50 steps are recommended for most cases and take about 120 seconds for inference.<br>")
with gr.Row():
num_inference_steps = gr.Number(label="Inference Steps", value=50)
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
generate_button = gr.Button("🎬 Generate Video")
with gr.Column():
video_output = gr.Video(label="CogVideoX Generate Video", width=720, height=480)
with gr.Row():
download_video_button = gr.File(label="📥 Download Video", visible=False)
download_gif_button = gr.File(label="📥 Download GIF", visible=False)
gr.Markdown("""
<table border="1" style="width: 100%; text-align: left; margin-top: 20px;">
<tr>
<th>Prompt</th>
<th>Video URL</th>
<th>Inference Steps</th>
<th>Guidance Scale</th>
</tr>
<tr>
<td>A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.</td>
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/1.mp4">Video 1</a></td>
<td>50</td>
<td>6</td>
</tr>
<tr>
<td>The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from it’s tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds.</td>
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/2.mp4">Video 2</a></td>
<td>50</td>
<td>6</td>
</tr>
<tr>
<td>A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall.</td>
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/3.mp4">Video 3</a></td>
<td>50</td>
<td>6</td>
</tr>
<tr>
<td>In the haunting backdrop of a war-torn city, where ruins and crumbled walls tell a story of devastation, a poignant close-up frames a young girl. Her face is smudged with ash, a silent testament to the chaos around her. Her eyes glistening with a mix of sorrow and resilience, capturing the raw emotion of a world that has lost its innocence to the ravages of conflict.</td>
<td><a href="https://github.com/THUDM/CogVideo/raw/main/resources/videos/4.mp4">Video 4</a></td>
<td>50</td>
<td>6</td>
</tr>
</table>
""")
def generate(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
video_path = save_video(tensor)
video_update = gr.update(visible=True, value=video_path)
gif_path = convert_to_gif(video_path)
gif_update = gr.update(visible=True, value=gif_path)
return video_path, video_update, gif_update
def enhance_prompt_func(prompt):
return convert_prompt(prompt, retry_times=1)
generate_button.click(
generate,
inputs=[prompt, num_inference_steps, guidance_scale],
outputs=[video_output, download_video_button, download_gif_button]
)
enhance_button.click(
enhance_prompt_func,
inputs=[prompt],
outputs=[prompt]
)
if __name__ == "__main__":
demo.launch(server_name="127.0.0.1", server_port=7870, share=True)
"""
This script is used to create a Streamlit web application for generating videos using the CogVideoX model.
Run the script using Streamlit:
$ export OPENAI_API_KEY=your OpenAI key or ZhipuAI key
$ export OPENAI_BASE_URL=https://open.bigmodel.cn/api/paas/v4/ # required when using ZhipuAI; omit when using OpenAI
$ streamlit run web_demo.py
"""
import base64
import json
import os
import time
from datetime import datetime
from typing import List
import imageio
import numpy as np
import streamlit as st
import torch
from convert_demo import convert_prompt
from diffusers import CogVideoXPipeline
model_path: str = "THUDM/CogVideoX-2b"
# Load the model at the start
@st.cache_resource
def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPipeline:
"""
Load the CogVideoX model.
Args:
- model_path (str): Path to the model.
- dtype (torch.dtype): Data type for model.
- device (str): Device to load the model on.
Returns:
- CogVideoXPipeline: Loaded model pipeline.
"""
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
return pipe
# Define a function to generate video based on the provided prompt and model path
def generate_video(
pipe: CogVideoXPipeline,
prompt: str,
num_inference_steps: int = 50,
guidance_scale: float = 6.0,
num_videos_per_prompt: int = 1,
device: str = "cuda",
dtype: torch.dtype = torch.float16,
) -> List[np.ndarray]:
"""
Generate a video based on the provided prompt and model path.
Args:
- pipe (CogVideoXPipeline): The pipeline for generating videos.
- prompt (str): Text prompt for video generation.
- num_inference_steps (int): Number of inference steps.
- guidance_scale (float): Guidance scale for generation.
- num_videos_per_prompt (int): Number of videos to generate per prompt.
- device (str): Device to run the generation on.
- dtype (torch.dtype): Data type for the model.
Returns:
- List[np.ndarray]: Generated video frames.
"""
prompt_embeds, _ = pipe.encode_prompt(
prompt=prompt,
negative_prompt=None,
do_classifier_free_guidance=True,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=226,
device=device,
dtype=dtype,
)
pipe.enable_model_cpu_offload()
# Generate video
video = pipe(
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=torch.zeros_like(prompt_embeds),
).frames[0]
return video
def save_video(video: List[np.ndarray], path: str, fps: int = 8) -> None:
"""
Save the generated video to a file.
Args:
- video (List[np.ndarray]): Video frames.
- path (str): Path to save the video.
- fps (int): Frames per second for the video.
"""
# Remove the first frame
video = video[1:]
writer = imageio.get_writer(path, fps=fps, codec="libx264")
for frame in video:
np_frame = np.array(frame)
writer.append_data(np_frame)
writer.close()
def save_metadata(
prompt: str,
converted_prompt: str,
num_inference_steps: int,
guidance_scale: float,
num_videos_per_prompt: int,
path: str,
) -> None:
"""
Save metadata to a JSON file.
Args:
- prompt (str): Original prompt.
- converted_prompt (str): Converted prompt.
- num_inference_steps (int): Number of inference steps.
- guidance_scale (float): Guidance scale.
- num_videos_per_prompt (int): Number of videos per prompt.
- path (str): Path to save the metadata.
"""
metadata = {
"prompt": prompt,
"converted_prompt": converted_prompt,
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"num_videos_per_prompt": num_videos_per_prompt,
}
with open(path, "w") as f:
json.dump(metadata, f, indent=4)
def main() -> None:
"""
Main function to run the Streamlit web application.
"""
st.set_page_config(page_title="CogVideoX-Demo", page_icon="🎥", layout="wide")
st.write("# CogVideoX 🎥")
dtype: torch.dtype = torch.float16
device: str = "cuda"
global pipe
pipe = load_model(model_path, dtype, device)
with st.sidebar:
        st.info("It will take some time to generate a video (~90 seconds per video at 50 steps).", icon="ℹ️")
num_inference_steps: int = st.number_input("Inference Steps", min_value=1, max_value=100, value=50)
guidance_scale: float = st.number_input("Guidance Scale", min_value=0.0, max_value=20.0, value=6.0)
num_videos_per_prompt: int = st.number_input("Videos per Prompt", min_value=1, max_value=10, value=1)
share_links_container = st.empty()
prompt: str = st.chat_input("Prompt")
if prompt:
# Not Necessary, Suggestions
with st.spinner("Refining prompts..."):
converted_prompt = convert_prompt(prompt=prompt, retry_times=1)
if converted_prompt is None:
                st.error("Failed to refine the prompt; using the original one.")
st.info(f"**Origin prompt:** \n{prompt} \n \n**Convert prompt:** \n{converted_prompt}")
torch.cuda.empty_cache()
with st.spinner("Generating Video..."):
start_time = time.time()
video_paths = []
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"./output/{timestamp}"
os.makedirs(output_dir, exist_ok=True)
metadata_path = os.path.join(output_dir, "config.json")
save_metadata(
prompt, converted_prompt, num_inference_steps, guidance_scale, num_videos_per_prompt, metadata_path
)
for i in range(num_videos_per_prompt):
video_path = os.path.join(output_dir, f"output_{i + 1}.mp4")
video = generate_video(
pipe, converted_prompt or prompt, num_inference_steps, guidance_scale, 1, device, dtype
)
save_video(video, video_path, fps=8)
video_paths.append(video_path)
with open(video_path, "rb") as video_file:
video_bytes: bytes = video_file.read()
st.video(video_bytes, autoplay=True, loop=True, format="video/mp4")
torch.cuda.empty_cache()
used_time: float = time.time() - start_time
st.success(f"Videos generated in {used_time:.2f} seconds.")
# Create download links in the sidebar
with share_links_container:
st.sidebar.write("### Download Links:")
for video_path in video_paths:
video_name = os.path.basename(video_path)
with open(video_path, "rb") as f:
video_bytes: bytes = f.read()
b64_video = base64.b64encode(video_bytes).decode()
href = f'<a href="data:video/mp4;base64,{b64_video}" download="{video_name}">Download {video_name}</a>'
st.sidebar.markdown(href, unsafe_allow_html=True)
if __name__ == "__main__":
main()
import argparse
import os
import torch
import json
import warnings
import omegaconf
from omegaconf import OmegaConf
from sat.helpers import print_rank0
from sat import mpu
from sat.arguments import set_random_seed
from sat.arguments import add_training_args, add_evaluation_args, add_data_args
import torch.distributed
def add_model_config_args(parser):
"""Model arguments"""
group = parser.add_argument_group("model", "model configuration")
group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
group.add_argument(
"--model-parallel-size", type=int, default=1, help="size of the model parallel. only use if you are an expert."
)
group.add_argument("--force-pretrain", action="store_true")
group.add_argument("--device", type=int, default=-1)
group.add_argument("--debug", action="store_true")
group.add_argument("--log-image", type=bool, default=True)
return parser
def add_sampling_config_args(parser):
"""Sampling configurations"""
group = parser.add_argument_group("sampling", "Sampling Configurations")
group.add_argument("--output-dir", type=str, default="samples")
group.add_argument("--input-dir", type=str, default=None)
group.add_argument("--input-type", type=str, default="cli")
group.add_argument("--input-file", type=str, default="input.txt")
group.add_argument("--final-size", type=int, default=2048)
group.add_argument("--sdedit", action="store_true")
group.add_argument("--grid-num-rows", type=int, default=1)
group.add_argument("--force-inference", action="store_true")
group.add_argument("--lcm_steps", type=int, default=None)
group.add_argument("--sampling-num-frames", type=int, default=32)
group.add_argument("--sampling-fps", type=int, default=8)
group.add_argument("--only-save-latents", type=bool, default=False)
group.add_argument("--only-log-video-latents", type=bool, default=False)
group.add_argument("--latent-channels", type=int, default=32)
group.add_argument("--image2video", action="store_true")
return parser
def get_args(args_list=None, parser=None):
"""Parse all the args."""
if parser is None:
parser = argparse.ArgumentParser(description="sat")
else:
assert isinstance(parser, argparse.ArgumentParser)
parser = add_model_config_args(parser)
parser = add_sampling_config_args(parser)
parser = add_training_args(parser)
parser = add_evaluation_args(parser)
parser = add_data_args(parser)
import deepspeed
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(args_list)
args = process_config_to_args(args)
if not args.train_data:
print_rank0("No training data specified", level="WARNING")
assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
if args.train_iters is None and args.epochs is None:
args.train_iters = 10000 # default 10k iters
print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")
args.cuda = torch.cuda.is_available()
args.rank = int(os.getenv("RANK", "0"))
args.world_size = int(os.getenv("WORLD_SIZE", "1"))
if args.local_rank is None:
args.local_rank = int(os.getenv("LOCAL_RANK", "0")) # torchrun
if args.device == -1:
if torch.cuda.device_count() == 0:
args.device = "cpu"
elif args.local_rank is not None:
args.device = args.local_rank
else:
args.device = args.rank % torch.cuda.device_count()
if args.local_rank != args.device and args.mode != "inference":
raise ValueError(
"LOCAL_RANK (default 0) and args.device inconsistent. "
"This can only happens in inference mode. "
"Please use CUDA_VISIBLE_DEVICES=x for single-GPU training. "
)
if args.rank == 0:
print_rank0("using world size: {}".format(args.world_size))
if args.train_data_weights is not None:
assert len(args.train_data_weights) == len(args.train_data)
if args.mode != "inference": # training with deepspeed
args.deepspeed = True
if args.deepspeed_config is None: # not specified
deepspeed_config_path = os.path.join(
os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
)
with open(deepspeed_config_path) as file:
args.deepspeed_config = json.load(file)
override_deepspeed_config = True
else:
override_deepspeed_config = False
assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."
if args.zero_stage > 0 and not args.fp16 and not args.bf16:
print_rank0("Automatically set fp16=True to use ZeRO.")
args.fp16 = True
args.bf16 = False
if args.deepspeed:
if args.checkpoint_activations:
args.deepspeed_activation_checkpointing = True
else:
args.deepspeed_activation_checkpointing = False
if args.deepspeed_config is not None:
deepspeed_config = args.deepspeed_config
if override_deepspeed_config: # not specify deepspeed_config, use args
if args.fp16:
deepspeed_config["fp16"]["enabled"] = True
elif args.bf16:
deepspeed_config["bf16"]["enabled"] = True
deepspeed_config["fp16"]["enabled"] = False
else:
deepspeed_config["fp16"]["enabled"] = False
deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
optimizer_params_config = deepspeed_config["optimizer"]["params"]
optimizer_params_config["lr"] = args.lr
optimizer_params_config["weight_decay"] = args.weight_decay
else: # override args with values in deepspeed_config
if args.rank == 0:
print_rank0("Will override arguments with manually specified deepspeed_config!")
if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
args.fp16 = True
else:
args.fp16 = False
if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
args.bf16 = True
else:
args.bf16 = False
if "train_micro_batch_size_per_gpu" in deepspeed_config:
args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
if "gradient_accumulation_steps" in deepspeed_config:
args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
else:
args.gradient_accumulation_steps = None
if "optimizer" in deepspeed_config:
optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
args.lr = optimizer_params_config.get("lr", args.lr)
args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
args.deepspeed_config = deepspeed_config
# initialize distributed and random seed because it always seems to be necessary.
initialize_distributed(args)
args.seed = args.seed + mpu.get_data_parallel_rank()
set_random_seed(args.seed)
return args
def initialize_distributed(args):
"""Initialize torch.distributed."""
if torch.distributed.is_initialized():
if mpu.model_parallel_is_initialized():
if args.model_parallel_size != mpu.get_model_parallel_world_size():
raise ValueError(
"model_parallel_size is inconsistent with prior configuration."
"We currently do not support changing model_parallel_size."
)
return False
else:
if args.model_parallel_size > 1:
warnings.warn(
"model_parallel_size > 1 but torch.distributed is not initialized via SAT."
"Please carefully make sure the correctness on your own."
)
mpu.initialize_model_parallel(args.model_parallel_size)
return True
# the automatic assignment of devices has been moved to arguments.py
if args.device == "cpu":
pass
else:
torch.cuda.set_device(args.device)
# Call the init process
init_method = "tcp://"
args.master_ip = os.getenv("MASTER_ADDR", "localhost")
if args.world_size == 1:
from sat.helpers import get_free_port
default_master_port = str(get_free_port())
else:
default_master_port = "6000"
args.master_port = os.getenv("MASTER_PORT", default_master_port)
init_method += args.master_ip + ":" + args.master_port
torch.distributed.init_process_group(
backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
)
# Set the model-parallel / data-parallel communicators.
mpu.initialize_model_parallel(args.model_parallel_size)
# Set vae context parallel group equal to model parallel group
from sgm.util import set_context_parallel_group, initialize_context_parallel
if args.model_parallel_size <= 2:
set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
else:
initialize_context_parallel(2)
# mpu.initialize_model_parallel(1)
# Optional DeepSpeed Activation Checkpointing Features
if args.deepspeed:
import deepspeed
deepspeed.init_distributed(
dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
)
# # It seems that it has no negative influence to configure it even without using checkpointing.
# deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
else:
# in model-only mode, we don't want to init deepspeed, but we still need to init the rng tracker for model_parallel, just because we save the seed by default when dropout.
try:
import deepspeed
from deepspeed.runtime.activation_checkpointing.checkpointing import (
_CUDA_RNG_STATE_TRACKER,
_MODEL_PARALLEL_RNG_TRACKER_NAME,
)
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1) # default seed 1
except Exception as e:
from sat.helpers import print_rank0
print_rank0(str(e), level="DEBUG")
return True
def process_config_to_args(args):
"""Fetch args from only --base"""
configs = [OmegaConf.load(cfg) for cfg in args.base]
config = OmegaConf.merge(*configs)
args_config = config.pop("args", OmegaConf.create())
for key in args_config:
if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
arg = OmegaConf.to_object(args_config[key])
else:
arg = args_config[key]
if hasattr(args, key):
setattr(args, key, arg)
if "model" in config:
model_config = config.pop("model", OmegaConf.create())
args.model_config = model_config
if "deepspeed" in config:
deepspeed_config = config.pop("deepspeed", OmegaConf.create())
args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
if "data" in config:
data_config = config.pop("data", OmegaConf.create())
args.data_config = data_config
return args
'''
# --------------------------------------------------------------------------------
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''
import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToTensor, ToPILImage
from einops import rearrange
def adain_color_fix(target: Image, source: Image):
# Convert images to tensors (b, t, c, h, w)
target, source = target.squeeze(0), source.squeeze(0)
source = (source + 1) / 2
# Apply adaptive instance normalization
result_tensor_list = []
for i in range(0, target.shape[0]):
result_tensor_list.append(adaptive_instance_normalization(target[i].unsqueeze(0), source[i].unsqueeze(0)))
# Convert tensor back to image
result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
# result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
result_video = result_tensor.unsqueeze(0)
return result_video
def wavelet_color_fix(target, source):
# Convert images to tensors
target = rearrange(target, 'T H W C -> T C H W') / 255
source = (source + 1) / 2
# Apply wavelet reconstruction
result_tensor_list = []
for i in range(0, target.shape[0]):
result_tensor_list.append(wavelet_reconstruction(target[i].unsqueeze(0), source[i].unsqueeze(0)))
# Convert tensor back to image
result_tensor = torch.cat(result_tensor_list, dim=0).clamp_(0.0, 1.0)
result_video = rearrange(result_tensor, "T C H W -> T H W C") * 255
return result_video
def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
"""Adaptive instance normalization.
Adjust the reference features to have similar colors and illumination
as those in the degraded features.
Args:
content_feat (Tensor): The reference features.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition, so that the content will have the same color as the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
args:
latent_channels: 16
mode: inference
load: '/mnt/bn/videodataset/VSR/pretrained_models/cogvideox/transformer'
batch_size: 1
input_type: txt
input_file: ../../input/text/prompt.txt
sampling_num_frames: 7 # Must be 13, 11 or 9
sampling_fps: 8 # (invalid)
# fp16: True
bf16: True
output_dir: ./output
force_inference: True
model:
scale_factor: 0.7 # different from cogvideox_2b_infer.yaml
disable_first_stage_autocast: true
not_trainable_prefixes: ['all'] # Using Lora
log_keys:
- txt
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
quantize_c_noise: False
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VideoScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
network_config:
target: dit_video_concat.DiffusionTransformer
params:
time_embed_dim: 512
elementwise_affine: True
num_frames: 49
time_compressed_rate: 4
latent_width: 90
latent_height: 60
num_layers: 42 # different from cogvideox_2b_infer.yaml
patch_size: 2
in_channels: 16
out_channels: 16
hidden_size: 3072 # different from cogvideox_2b_infer.yaml
adm_in_channels: 256
num_attention_heads: 48 # different from cogvideox_2b_infer.yaml
transformer_args:
checkpoint_activations: True
vocab_size: 1
max_sequence_length: 64
layernorm_order: pre
skip_init: false
model_parallel_size: 1
is_decoder: false
modules:
pos_embed_config:
target: dit_video_concat.Rotary3DPositionEmbeddingMixin # different from cogvideox_2b_infer.yaml
params:
hidden_size_head: 64
text_length: 226
lora_config: # Using Lora
target: sat.model.finetune.lora2.LoraMixin
params:
r: 512
patch_embed_config:
target: dit_video_concat.ImagePatchEmbeddingMixin
params:
text_hidden_size: 4096
adaln_layer_config:
target: dit_video_concat.AdaLNMixin
params:
qk_ln: True
final_layer_config:
target: dit_video_concat.FinalLayerMixin
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: false
input_key: txt
ucg_rate: 0.1
target: sgm.modules.encoders.modules.FrozenT5Embedder
params:
model_dir: "/mnt/bn/videodataset/VSR/pretrained_models/cogvideox/t5-v1_1-xxl"
max_length: 226
first_stage_config:
target: vae_modules.autoencoder.VideoAutoencoderInferenceWrapper
params:
cp_size: 1
ckpt_path: "/mnt/bn/videodataset/VSR/pretrained_models/cogvideox/vae/3d-vae.pt"
ignore_keys: [ 'loss' ]
loss_config:
target: torch.nn.Identity
regularizer_config:
target: vae_modules.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: vae_modules.cp_enc_dec.ContextParallelEncoder3D
params:
double_z: true
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: True
decoder_config:
target: vae_modules.cp_enc_dec.ContextParallelDecoder3D
params:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 2, 4 ]
attn_resolutions: [ ]
num_res_blocks: 3
dropout: 0.0
gather_norm: False
loss_fn_config:
target: sgm.modules.diffusionmodules.loss.SRDiffusionLoss
params:
offset_noise_level: 0
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
params:
uniform_sampling: True
num_idx: 1000
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0 # different from cogvideox_2b_infer.yaml
sampler_config:
target: sgm.modules.diffusionmodules.sampling.VPSDEDPMPP2MSampler
params:
num_steps: 50
verbose: True
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.ZeroSNRDDPMDiscretization
params:
shift_scale: 1.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.DynamicCFG
params:
scale: 6
exp: 5
num_steps: 50
import io
import os
import sys
import glob
import torchvision
from functools import partial
import math
import torchvision.transforms as TT
from einops import rearrange
from sgm.webds import MetaDistributedWebDataset
import random
from fractions import Fraction
from typing import Union, Optional, Dict, Any, Tuple
from torchvision.io.video import av
import numpy as np
import torch
from torchvision.io import _video_opt
from torchvision.io.video import _check_av_available, _read_from_stream, _align_audio_frames
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import decord
from decord import VideoReader
from torch.utils.data import Dataset
import random
import torch.nn.functional as F
def read_video(
filename: str,
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Reads a video from a file, returning both the video frames and the audio frames
Args:
filename (str): path to the video file
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The start presentation time of the video
end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
The end presentation time
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
either 'pts' or 'sec'. Defaults to 'pts'.
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
Returns:
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
_check_av_available()
if end_pts is None:
end_pts = float("inf")
if end_pts < start_pts:
raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
info = {}
audio_frames = []
audio_timebase = _video_opt.default_timebase
with av.open(filename, metadata_errors="ignore") as container:
if container.streams.audio:
audio_timebase = container.streams.audio[0].time_base
if container.streams.video:
video_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.video[0],
{"video": 0},
)
video_fps = container.streams.video[0].average_rate
# guard against potentially corrupted files
if video_fps is not None:
info["video_fps"] = float(video_fps)
if container.streams.audio:
audio_frames = _read_from_stream(
container,
start_pts,
end_pts,
pts_unit,
container.streams.audio[0],
{"audio": 0},
)
info["audio_fps"] = container.streams.audio[0].rate
aframes_list = [frame.to_ndarray() for frame in audio_frames]
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == "sec":
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf"):
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
return vframes, aframes, info
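# Hedged usage sketch (not part of the upstream file): probe a clip's fps and audio with the
# patched reader above. "example.mp4" is a placeholder path for illustration only. Note that
# this variant fills `info` and `aframes` but leaves `vframes` as an empty placeholder tensor.
def _example_read_video(path="example.mp4"):
    vframes, aframes, info = read_video(path, pts_unit="sec", output_format="TCHW")
    print(info.get("video_fps"), aframes.shape, vframes.shape)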
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
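# Hedged usage sketch (illustration only): resize and center-crop a 640x360 clip to the
# 480x720 resolution used elsewhere in this file.
def _example_resize_for_rectangle_crop():
    frames = torch.rand(17, 3, 360, 640)  # [T, C, H, W]
    cropped = resize_for_rectangle_crop(frames, [480, 720], reshape_mode="center")
    print(cropped.shape)  # expected: torch.Size([17, 3, 480, 720])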
def pad_last_frame(tensor, num_frames):
    # T, H, W, C
    if tensor.shape[0] < num_frames:
        # the frame axis is dim 0, so the shortfall is num_frames - tensor.shape[0];
        # pad by repeating the trailing frames until the clip reaches num_frames
        last_frame = tensor[-int(num_frames - tensor.shape[0]) :]
        padded_tensor = torch.cat([tensor, last_frame], dim=0)
        return padded_tensor
    else:
        return tensor[:num_frames]
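# Hedged usage sketch (illustration only): pad a 3-frame clip up to 5 frames by repeating its tail.
def _example_pad_last_frame():
    clip = torch.zeros(3, 4, 4, 3)  # [T, H, W, C] with only 3 frames
    print(pad_last_frame(clip, 5).shape)  # expected: torch.Size([5, 4, 4, 3])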
def load_video(
video_data,
sampling="uniform",
duration=None,
num_frames=4,
wanted_fps=None,
actual_fps=None,
skip_frms_num=0.0,
nb_read_frames=None,
):
decord.bridge.set_bridge("torch")
vr = VideoReader(uri=video_data, height=-1, width=-1)
if nb_read_frames is not None:
ori_vlen = nb_read_frames
else:
ori_vlen = min(int(duration * actual_fps) - 1, len(vr))
max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
    # randint requires integer bounds; skip_frms_num may be passed as a float
    start = random.randint(int(skip_frms_num), max_seek + 1)
end = int(start + num_frames / wanted_fps * actual_fps)
n_frms = num_frames
if sampling == "uniform":
indices = np.arange(start, end, (end - start) / n_frms).astype(int)
else:
raise NotImplementedError
# get_batch -> T, H, W, C
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
return pad_last_frame(tensor_frms, num_frames)
import threading
def load_video_with_timeout(*args, **kwargs):
video_container = {}
def target_function():
video = load_video(*args, **kwargs)
video_container["video"] = video
thread = threading.Thread(target=target_function)
thread.start()
timeout = 20
thread.join(timeout)
if thread.is_alive():
print("Loading video timed out")
raise TimeoutError
    video = video_container.get("video", None)
    if video is None:
        # the worker thread finished but failed to produce a video (e.g. it raised internally)
        raise RuntimeError("Video loading thread exited without returning a video")
    return video.contiguous()
def process_video(
video_path,
image_size=None,
duration=None,
num_frames=4,
wanted_fps=None,
actual_fps=None,
skip_frms_num=0.0,
nb_read_frames=None,
):
"""
video_path: str or io.BytesIO
image_size: .
duration: preknow the duration to speed up by seeking to sampled start. TODO by_pass if unknown.
num_frames: wanted num_frames.
wanted_fps: .
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
"""
video = load_video_with_timeout(
video_path,
duration=duration,
num_frames=num_frames,
wanted_fps=wanted_fps,
actual_fps=actual_fps,
skip_frms_num=skip_frms_num,
nb_read_frames=nb_read_frames,
)
# --- copy and modify the image process ---
video = video.permute(0, 3, 1, 2) # [T, C, H, W]
# resize
if image_size is not None:
video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")
return video
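# Hedged usage sketch (illustration only): decode, temporally sample, and spatially crop one
# clip to the shape expected downstream ([T, C, H, W], uint8). "example.mp4" is a placeholder
# path; the fps/duration values would normally come from dataset metadata.
def _example_process_video(path="example.mp4"):
    vr = VideoReader(uri=path, height=-1, width=-1)
    actual_fps = vr.get_avg_fps()
    video = process_video(
        path,
        image_size=[480, 720],
        duration=len(vr) / actual_fps,
        num_frames=25,
        wanted_fps=8,
        actual_fps=actual_fps,
        skip_frms_num=3,
    )
    print(video.shape)  # expected: torch.Size([25, 3, 480, 720])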
def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
while True:
r = next(src)
if "mp4" in r:
video_data = r["mp4"]
elif "avi" in r:
video_data = r["avi"]
else:
print("No video data found")
continue
if txt_key not in r:
txt = ""
else:
txt = r[txt_key]
if isinstance(txt, bytes):
txt = txt.decode("utf-8")
else:
txt = str(txt)
duration = r.get("duration", None)
if duration is not None:
duration = float(duration)
else:
continue
actual_fps = r.get("fps", None)
if actual_fps is not None:
actual_fps = float(actual_fps)
else:
continue
required_frames = num_frames / fps * actual_fps + 2 * skip_frms_num
required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
if duration is not None and duration < required_duration:
continue
try:
frames = process_video(
io.BytesIO(video_data),
num_frames=num_frames,
wanted_fps=fps,
image_size=image_size,
duration=duration,
actual_fps=actual_fps,
skip_frms_num=skip_frms_num,
)
frames = (frames - 127.5) / 127.5
except Exception as e:
print(e)
continue
item = {
"mp4": frames,
"txt": txt,
"num_frames": num_frames,
"fps": fps,
}
yield item
class VideoDataset(MetaDistributedWebDataset):
def __init__(
self,
path,
image_size,
num_frames,
fps,
skip_frms_num=0.0,
nshards=sys.maxsize,
seed=1,
meta_names=None,
shuffle_buffer=1000,
include_dirs=None,
txt_key="caption",
**kwargs,
):
if seed == -1:
seed = random.randint(0, 1000000)
if meta_names is None:
meta_names = []
if path.startswith(";"):
path, include_dirs = path.split(";", 1)
super().__init__(
path,
partial(
process_fn_video, num_frames=num_frames, image_size=image_size, fps=fps, skip_frms_num=skip_frms_num
),
seed,
meta_names=meta_names,
shuffle_buffer=shuffle_buffer,
nshards=nshards,
include_dirs=include_dirs,
)
@classmethod
def create_dataset_function(cls, path, args, **kwargs):
return cls(path, **kwargs)
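# Hedged usage sketch (illustration only): build the webdataset-backed training set directly.
# The shard path and hyper-parameter values below are placeholders, not values from this repo.
def _example_video_dataset():
    return VideoDataset.create_dataset_function(
        "/path/to/webdataset/shards",
        None,
        image_size=[480, 720],
        num_frames=25,
        fps=8,
        skip_frms_num=3.0,
        seed=1,
    )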
class SFTDataset(Dataset):
def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
"""
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
"""
super(SFTDataset, self).__init__()
self.videos_list = []
self.captions_list = []
self.num_frames_list = []
self.fps_list = []
decord.bridge.set_bridge("torch")
for root, dirnames, filenames in os.walk(data_dir):
for filename in filenames:
if filename.endswith(".mp4"):
video_path = os.path.join(root, filename)
vr = VideoReader(uri=video_path, height=-1, width=-1)
actual_fps = vr.get_avg_fps()
ori_vlen = len(vr)
if ori_vlen / actual_fps * fps > max_num_frames:
num_frames = max_num_frames
start = int(skip_frms_num)
end = int(start + num_frames / fps * actual_fps)
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else:
if ori_vlen > max_num_frames:
num_frames = max_num_frames
start = int(skip_frms_num)
end = int(ori_vlen - skip_frms_num)
indices = np.arange(start, end, (end - start) / num_frames).astype(int)
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = (
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
)
tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
else:
def nearest_smaller_4k_plus_1(n):
remainder = n % 4
if remainder == 0:
return n - 3
else:
return n - remainder + 1
start = int(skip_frms_num)
end = int(ori_vlen - skip_frms_num)
num_frames = nearest_smaller_4k_plus_1(
end - start
) # 3D VAE requires the number of frames to be 4k+1
end = int(start + num_frames)
temp_frms = vr.get_batch(np.arange(start, end))
assert temp_frms is not None
tensor_frms = (
torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
)
tensor_frms = pad_last_frame(
tensor_frms, num_frames
) # the len of indices may be less than num_frames, due to round error
tensor_frms = tensor_frms.permute(0, 3, 1, 2) # [T, H, W, C] -> [T, C, H, W]
tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
tensor_frms = (tensor_frms - 127.5) / 127.5
self.videos_list.append(tensor_frms)
# caption
caption_path = os.path.join(root, filename.replace(".mp4", ".txt")).replace("videos", "labels")
                    if os.path.exists(caption_path):
                        with open(caption_path, "r") as caption_file:
                            caption_lines = caption_file.read().splitlines()
                        # guard against an empty caption file
                        caption = caption_lines[0] if caption_lines else ""
                    else:
                        caption = ""
self.captions_list.append(caption)
self.num_frames_list.append(num_frames)
self.fps_list.append(fps)
def __getitem__(self, index):
item = {
"mp4": self.videos_list[index],
"txt": self.captions_list[index],
"num_frames": self.num_frames_list[index],
"fps": self.fps_list[index],
}
return item
def __len__(self):
return len(self.fps_list)
@classmethod
def create_dataset_function(cls, path, args, **kwargs):
return cls(data_dir=path, **kwargs)
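# Hedged usage sketch (illustration only): "/path/to/sft_data" is a placeholder directory
# containing .mp4 files; captions are looked up by swapping "videos" for "labels" in the path.
def _example_sft_dataset():
    from torch.utils.data import DataLoader
    dataset = SFTDataset(
        data_dir="/path/to/sft_data",
        video_size=[480, 720],
        fps=8,
        max_num_frames=49,
        skip_frms_num=3,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader))
    print(batch["mp4"].shape, batch["num_frames"], batch["fps"])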
class PairedCaptionDataset(Dataset):
def __init__(self, data_dir, video_size=None, fps=None, max_num_frames=None, skip_frms_num=None, null_text_ratio=0.5, num_frames=25):
"""
skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
"""
super(PairedCaptionDataset, self).__init__()
self.null_text_ratio = null_text_ratio
self.num_frames = num_frames
self.lr_list = []
self.gt_list = []
self.tag_path_list = []
lr_path = data_dir + '/lq'
tag_path = data_dir + '/text'
self.tag_path = tag_path
gt_path = data_dir + '/gt'
self.lr_list += glob.glob(os.path.join(lr_path, '*.mp4'))
self.gt_list += glob.glob(os.path.join(gt_path, '*.mp4'))
self.tag_path_list += glob.glob(os.path.join(tag_path, '*.txt'))
assert len(self.lr_list) == len(self.gt_list)
assert len(self.lr_list) == len(self.tag_path_list)
def __getitem__(self, index):
gt_path = self.gt_list[index]
vframes_gt, _, info = torchvision.io.read_video(filename=gt_path, pts_unit="sec", output_format="TCHW")
fps = info['video_fps']
if vframes_gt.shape[-1] > 720:
vframes_gt = F.interpolate(vframes_gt, scale_factor=2 / 3, mode='bilinear')
vframes_gt = TT.functional.center_crop(vframes_gt, (480, 720))
# elif vframes_gt.shape[-1] < 720:
# vframes_gt = F.interpolate(vframes_gt, size=(480, 720), mode='bilinear')
vframes_gt = (vframes_gt / 255) * 2 - 1
lq_path = self.lr_list[index]
vframes_lq, _, _ = torchvision.io.read_video(filename=lq_path, pts_unit="sec", output_format="TCHW")
if vframes_lq.shape[-1] > 720:
vframes_lq = F.interpolate(vframes_lq, scale_factor=2 / 3, mode='bilinear')
vframes_lq = TT.functional.center_crop(vframes_lq, (480, 720))
elif vframes_lq.shape[-1] < 720:
vframes_lq = F.interpolate(vframes_lq, scale_factor=4, mode='bicubic')
vframes_lq = (vframes_lq / 255) * 2 - 1
if random.random() < self.null_text_ratio:
tag = ''
else:
tag_path = os.path.join(self.tag_path, os.path.splitext(os.path.basename(gt_path))[0] + '.txt')
            with open(tag_path, 'r') as tag_file:
                tag = tag_file.read()
return {
"mp4": vframes_gt[:self.num_frames, :, :, :],
"txt": tag,
"lq": vframes_lq[:self.num_frames, :, :, :], # frames = 4k + 1
"num_frames": self.num_frames,
"fps": fps,
}
def __len__(self):
return len(self.gt_list)
@classmethod
def create_dataset_function(cls, path, args, **kwargs):
return cls(data_dir=path, **kwargs)
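# Hedged usage sketch (illustration only): "/path/to/paired_data" is a placeholder directory
# with lq/, gt/ and text/ sub-folders holding matching .mp4 and .txt files.
def _example_paired_caption_dataset():
    from torch.utils.data import DataLoader
    dataset = PairedCaptionDataset(data_dir="/path/to/paired_data", null_text_ratio=0.5, num_frames=25)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader))
    print(batch["mp4"].shape, batch["lq"].shape, batch["fps"])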