Commit e2ec88b4 authored by jerrrrry

Initial commit
[中文阅读](./README_zh.md)
# HunyuanVideo Latent Feature Extraction Tool
This project provides an efficient tool for extracting latent features from videos, preparing them for subsequent video generation and processing tasks.
## Features
- Support for various video formats and resolutions
- Multi-GPU parallel processing for improved efficiency
- Support for multiple aspect ratios
- High-performance VAE model for feature extraction
- Automatic skipping of already processed videos, supporting resume functionality
## Usage
### 1. Configuration File
#### Input Dataset Format
The input video metadata file (`meta_file.list`) is a plain-text list of JSON file paths, one per line, each pointing to a per-video metadata JSON with the fields shown below.
The format of `meta_file.list` (e.g., `./assets/demo/i2v_lora/train_dataset/meta_file.list`) is as follows:
```
/path/to/0.json
/path/to/1.json
/path/to/2.json
...
```
**IMPORTANT:** Make sure each video's `video_id` (the video filename without its extension) is unique.
The format of `/path/to/0.json` (e.g., `./assets/demo/i2v_lora/train_dataset/meta_data.json`) is as follows:
```json
{
    "video_path": "/path/to/video.mp4",
    "raw_caption": {
        "long caption": "Detailed description text of the video"
    }
}
```
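For reference, here is a minimal sketch of how such per-video JSON files and the accompanying `meta_file.list` could be generated; the video paths, captions, and output directory below are placeholders, not part of the repository.
```python
import json
from pathlib import Path

# Hypothetical inputs: adjust to your own dataset layout.
video_paths = ["/path/to/videos/clip_000.mp4", "/path/to/videos/clip_001.mp4"]
captions = ["A dog runs across a sunny beach.", "A close-up of rain falling on a window."]
meta_dir = Path("/path/to/train_dataset")
meta_dir.mkdir(parents=True, exist_ok=True)

json_paths = []
for i, (video_path, caption) in enumerate(zip(video_paths, captions)):
    meta = {
        "video_path": video_path,
        "raw_caption": {"long caption": caption},
    }
    json_path = meta_dir / f"{i}.json"
    json_path.write_text(json.dumps(meta, ensure_ascii=False))
    json_paths.append(str(json_path))

# meta_file.list is just one JSON path per line.
(meta_dir / "meta_file.list").write_text("\n".join(json_paths) + "\n")
```
Any layout works as long as every line of `meta_file.list` points to a JSON file with these two fields and the video filenames are unique.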
Configure parameters in `hyvideo/hyvae_extract/vae.yaml`:
```yaml
vae_path: "./ckpts/hunyuan-video-i2v-720p/vae" # VAE model path
video_url_files: "/path/to/meta_file.list" # Video metadata file list
output_base_dir: "/path/to/output/directory" # Output directory
sample_n_frames: 129 # Number of frames to sample
target_size: # Target size
- bucket_size
- bucket_size
enable_multi_aspect_ratio: True # Enable multiple aspect ratios
use_stride: True # Use stride sampling
```
#### Bucket Size Reference
The `target_size` parameter defines the resolution bucket size. Here are the recommended values for different quality levels:
| Quality | Bucket Size | Typical Resolution |
|---------|-------------|-------------------|
| 720p | 960 | 1280×720 or similar |
| 540p | 720 | 960×540 or similar |
| 360p | 480 | 640×360 or similar |
When `enable_multi_aspect_ratio` is set to `True`, the system will use these bucket sizes as a base to generate multiple aspect ratio buckets. For optimal performance, choose a bucket size that balances quality and memory usage based on your hardware capabilities.
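As a rough illustration, the sketch below mirrors the `generate_crop_size_list` helper used by the extraction dataset to enumerate the aspect-ratio buckets derived from a base bucket size (shown here for the 360p value of 480; for bucket sizes of 540 and above the dataset code switches to a patch size of 32):
```python
def generate_crop_size_list(base_size=480, patch_size=16, max_ratio=4.0):
    """Enumerate (width, height) buckets whose area stays at or below base_size**2
    and whose aspect ratio is capped at max_ratio."""
    num_patches = round((base_size / patch_size) ** 2)
    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            hp += 1
        else:
            wp -= 1
    return crop_size_list

buckets = generate_crop_size_list(base_size=480)
print(len(buckets), (480, 480) in buckets)  # e.g. the square 480x480 bucket is one of them
```
The number of buckets printed here corresponds to the `Multi-aspect-ratio bucket num: ...` line logged at startup.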
### 2. Run Extraction
```bash
# Set environment variables
export HOST_GPU_NUM=8 # Set the number of GPUs to use
# Run extraction script
cd HunyuanVideo-I2V
bash hyvideo/hyvae_extract/start.sh
```
### 3. Single GPU Run
```bash
cd HunyuanVideo-I2V
export PYTHONPATH=${PYTHONPATH}:`pwd`
export HOST_GPU_NUM=1
CUDA_VISIBLE_DEVICES=0 python3 -u hyvideo/hyvae_extract/run.py --local_rank 0 --config 'hyvideo/hyvae_extract/vae.yaml'
```
## Output Files
The program generates the following files in the specified output directory:
1. `{video_id}.npy` - Latent feature array of the video
2. `json_path/{video_id}.json` - JSON file containing video metadata, including:
   - video_id: Video ID
   - latent_shape: Shape of the latent features
   - video_path: Original video path
   - prompt: Video description/prompt
   - npy_save_path: Path where the latent features are saved
```
output_base_dir/
├── {video_id_1}.npy           # Latent feature array for video 1
├── {video_id_2}.npy           # Latent feature array for video 2
├── {video_id_3}.npy           # Latent feature array for video 3
│   ...
├── {video_id_n}.npy           # Latent feature array for video n
└── json_path/                 # Directory containing metadata JSON files
    ├── {video_id_1}.json      # Metadata for video 1
    ├── {video_id_2}.json      # Metadata for video 2
    ├── {video_id_3}.json      # Metadata for video 3
    │   ...
    └── {video_id_n}.json      # Metadata for video n
```
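As a quick sanity check, the extracted latent and its metadata for a given video can be loaded back as follows (a minimal sketch; `output_base_dir` and `video_id` are placeholders):
```python
import json
from pathlib import Path
import numpy as np

output_base_dir = Path("/path/to/output/directory")  # same value as in vae.yaml
video_id = "example_video"                            # placeholder id

latent = np.load(output_base_dir / f"{video_id}.npy")
with open(output_base_dir / "json_path" / f"{video_id}.json", "r", encoding="utf-8") as f:
    meta = json.load(f)

print(latent.shape)    # should match meta["latent_shape"]
print(meta["prompt"])  # the "long caption" text of the video
```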
## Advanced Configuration
### Multiple Aspect Ratio Processing
When `enable_multi_aspect_ratio` is set to `True`, the system selects the target size closest to the original aspect ratio of the video, rather than forcing it to be cropped to a fixed size. This is useful for maintaining the integrity of the video content.
### Stride Sampling
When `use_stride` is set to `True`, the system automatically adjusts the sampling stride based on the video's frame rate (a code sketch follows the list below):
- When frame rate >= 50fps, stride is 2
- When frame rate < 50fps, stride is 1
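Concretely, the frame indices read from each video follow the logic below. This is a minimal sketch mirroring the sampling code in the extraction dataset (`dataset.py`); the fps and frame-count values in the example calls are made up for illustration:
```python
def select_frame_indices(video_length: int, fps: float, sample_n_frames: int = 129,
                         vae_time_compression_ratio: int = 4, use_stride: bool = True):
    """Return the indices of the frames that are read before VAE encoding."""
    stride = 2 if (use_stride and int(fps) >= 50) else 1
    if video_length < sample_n_frames * stride:
        # Clamp short videos to the largest 4n+1-style length the VAE accepts.
        n_frames = video_length - (video_length - 1) % (vae_time_compression_ratio * stride)
    else:
        n_frames = sample_n_frames * stride
    return list(range(0, n_frames, stride))

print(len(select_frame_indices(video_length=300, fps=60)))  # 129 indices at stride 2
print(len(select_frame_indices(video_length=200, fps=24)))  # 129 indices at stride 1
```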
[English](./README.md)
# HunyuanVideo Latent Feature Extraction Tool
This project provides an efficient tool for extracting latent features from videos, preparing them for subsequent video generation and processing tasks.
## Features
- Support for various video formats and resolutions
- Multi-GPU parallel processing for improved efficiency
- Support for multiple aspect ratios
- High-performance VAE model for feature extraction
- Automatic skipping of already processed videos, supporting resumable runs
## Usage
### 1. Configuration File
#### Input Dataset Format
The input video metadata file (`meta_file.list`) is a plain-text list of JSON file paths, one per line, each pointing to a per-video metadata JSON with the fields shown below.
The format of `meta_file.list` (e.g., `./assets/demo/i2v_lora/train_dataset/meta_file.list`) is as follows:
```
/path/to/0.json
/path/to/1.json
/path/to/2.json
...
```
**IMPORTANT:** Make sure each video's filename (and therefore its `video_id`) is unique.
The format of `/path/to/0.json` (e.g., `./assets/demo/i2v_lora/train_dataset/meta_data.json`) is as follows:
```json
{
    "video_path": "/path/to/video.mp4",
    "raw_caption": {
        "long caption": "Detailed description text of the video"
    }
}
```
Configure parameters in `hyvideo/hyvae_extract/vae.yaml`:
```yaml
vae_path: "./ckpts/hunyuan-video-i2v-720p/vae" # VAE model path
video_url_files: "/path/to/meta_file.list" # Video metadata file list
output_base_dir: "/path/to/output/directory" # Output directory
sample_n_frames: 129 # Number of frames to sample
target_size: # Target size
- bucket_size
- bucket_size
enable_multi_aspect_ratio: True # Enable multiple aspect ratios
use_stride: True # Use stride sampling
```
#### Bucket Size Reference
The `target_size` parameter defines the resolution bucket size. Recommended values for different quality levels:
| Quality | Bucket Size | Typical Resolution |
|---------|-------------|-------------------|
| 720p | 960 | 1280×720 or similar |
| 540p | 720 | 960×540 or similar |
| 360p | 480 | 640×360 or similar |
When `enable_multi_aspect_ratio` is set to `True`, the system uses these bucket sizes as a base to generate buckets for multiple aspect ratios. For best performance, choose a bucket size that balances quality and memory usage according to your hardware.
### 2. Run Extraction
```bash
# Set environment variables
export HOST_GPU_NUM=8 # Set the number of GPUs to use
# Run extraction script
cd HunyuanVideo-I2V
bash hyvideo/hyvae_extract/start.sh
```
### 3. Single GPU Run
```bash
cd HunyuanVideo-I2V
export PYTHONPATH=${PYTHONPATH}:`pwd`
export HOST_GPU_NUM=1
CUDA_VISIBLE_DEVICES=0 python3 -u hyvideo/hyvae_extract/run.py --local_rank 0 --config 'hyvideo/hyvae_extract/vae.yaml'
```
## Output Files
The program generates the following files in the specified output directory:
1. `{video_id}.npy` - Latent feature array of the video
2. `json_path/{video_id}.json` - JSON file containing video metadata, including:
   - video_id: Video ID
   - latent_shape: Shape of the latent features
   - video_path: Original video path
   - prompt: Video description/prompt
   - npy_save_path: Path where the latent features are saved
```
output_base_dir/
├── {video_id_1}.npy           # Latent feature array for video 1
├── {video_id_2}.npy           # Latent feature array for video 2
├── {video_id_3}.npy           # Latent feature array for video 3
│   ...
├── {video_id_n}.npy           # Latent feature array for video n
└── json_path/                 # Directory containing metadata JSON files
    ├── {video_id_1}.json      # Metadata for video 1
    ├── {video_id_2}.json      # Metadata for video 2
    ├── {video_id_3}.json      # Metadata for video 3
    │   ...
    └── {video_id_n}.json      # Metadata for video n
```
## Advanced Configuration
### Multiple Aspect Ratio Processing
When `enable_multi_aspect_ratio` is set to `True`, the system selects the target size closest to the video's original aspect ratio instead of forcing a crop to a fixed size, which helps preserve the integrity of the video content.
### Stride Sampling
When `use_stride` is set to `True`, the system automatically adjusts the sampling stride based on the video's frame rate:
- When the frame rate is >= 50 fps, the stride is 2
- When the frame rate is < 50 fps, the stride is 1
from typing import Tuple, List
from decord import VideoReader
import urllib
import io
import os
import csv
import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset
import torchvision.transforms as transforms
from torchvision.transforms.functional import crop
from pathlib import Path
import sys
import json
def split_video_urls(meta_files: str, global_rank: int, world_size: int):
    """Evenly split the metadata JSON paths listed in `meta_files` across `world_size` ranks."""
    with open(meta_files, 'r') as f:
        meta_paths = [line.strip() for line in f if line.strip()]
num_videos = len(meta_paths)
num_videos_per_rank = num_videos // world_size
remainder = num_videos % world_size
# Calculate start and end indices
start = num_videos_per_rank * global_rank + min(global_rank, remainder)
end = start + num_videos_per_rank + (1 if global_rank < remainder else 0)
return start, end, meta_paths[start:end]
class MultiBucketDataset(IterableDataset):
    """Groups samples that share the same (T, H, W) shape into batches so they can be stacked."""
    def __init__(self, source: Dataset, batch_size: int, max_buf=64):
        super().__init__()
        self.source = source
        self.batch_size = batch_size
        self.buffer = {}  # (T, H, W) -> list of pending samples
        self.max_buf = max_buf  # flush partially filled buckets once this many samples are buffered
        self.size = 0  # total number of buffered samples
    @staticmethod
    def collate_fn(samples):
        pixel_values = torch.stack([sample["pixel_values"] for sample in samples]).contiguous()
        # Keep video_path and prompt as well; the extraction loop needs them per sample.
        batch = {
            "pixel_values": pixel_values,
            "videoid": [sample["videoid"] for sample in samples],
            "video_path": [sample["video_path"] for sample in samples],
            "prompt": [sample["prompt"] for sample in samples],
            "valid": [sample["valid"] for sample in samples],
        }
        return batch
def __iter__(self):
# split dataset
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
iter_start = 0
iter_end = len(self.source)
else:
worker_id = int(worker_info.id)
per_worker = len(self.source) // int(worker_info.num_workers)
per_worker += int(worker_id < len(self.source) % int(worker_info.num_workers))
if worker_id >= len(self.source) % int(worker_info.num_workers):
iter_start = worker_id * per_worker + len(self.source) % int(worker_info.num_workers)
else:
iter_start = worker_id * per_worker
iter_end = iter_start + per_worker
# bucketing
for i in range(iter_start, iter_end):
sample = self.source[i]
if sample["valid"] is False:
continue
T, C, H, W = sample["pixel_values"].shape
if (T, H, W) not in self.buffer:
self.buffer[(T, H, W)] = []
self.buffer[(T, H, W)].append(sample)
self.size += 1
if len(self.buffer[(T, H, W)]) == self.batch_size:
yield self.buffer[(T, H, W)]
self.size -= self.batch_size
self.buffer[(T, H, W)] = []
if self.size > self.max_buf and (len(self.buffer[(T, H, W)]) > 0):
self.size -= len(self.buffer[(T, H, W)])
yield self.buffer[(T, H, W)]
self.buffer[(T, H, W)] = []
# yield the remaining batch
for bucket, samples in self.buffer.items():
if len(samples) > 0:
yield samples
class VideoDataset(Dataset):
def __init__(
self,
meta_files: List[str],
latent_cache_dir: str,
sample_size: Tuple[int, int],
sample_n_frames: int,
is_center_crop: bool = True,
enable_multi_aspect_ratio: bool = False,
vae_time_compression_ratio: int = 4,
use_stride: bool = False,
):
if not Path(latent_cache_dir).exists():
Path(latent_cache_dir).mkdir(parents=True, exist_ok=True)
self.latent_cache_dir = latent_cache_dir
self.sample_n_frames = sample_n_frames
self.sample_size = tuple(sample_size)
self.is_center_crop = is_center_crop
self.vae_time_compression_ratio = vae_time_compression_ratio
self.enable_multi_aspect_ratio = enable_multi_aspect_ratio
self.dataset = meta_files
self.length = len(self.dataset)
self.use_stride = use_stride
# multi-aspect-ratio buckets
if enable_multi_aspect_ratio:
assert self.sample_size[0] == self.sample_size[1]
if self.sample_size[0] < 540:
self.buckets = self.generate_crop_size_list(base_size=self.sample_size[0])
else:
self.buckets = self.generate_crop_size_list(base_size=self.sample_size[0], patch_size=32)
self.aspect_ratios = np.array([float(w) / float(h) for w, h in self.buckets])
print(f"Multi-aspect-ratio bucket num: {len(self.buckets)}")
# image preprocess
if not enable_multi_aspect_ratio:
self.train_crop = transforms.CenterCrop(self.sample_size) if self.is_center_crop else transforms.RandomCrop(self.sample_size)
def request_ceph_data(self, path):
try:
video_reader = VideoReader(path)
except Exception as e:
print(f"Error: {e}")
raise
return video_reader
def preprocess_url(self, data_json_path):
with open(data_json_path, "r") as f:
data_dict = json.load(f)
video_path = data_dict['video_path']
video_id = video_path.split('/')[-1].split('.')[0]
prompt = data_dict['raw_caption']["long caption"]
item = {"video_path": video_path, "videoid": video_id, "prompt": prompt}
return item
def get_item(self, idx):
# Create Video Reader
data_json_path = self.dataset[idx]
video_item = self.preprocess_url(data_json_path)
# 20250322 pftq: fixed to return 5 values for consistency and "not enough values to unpack" error
# Skip if exists
latent_save_path = Path(self.latent_cache_dir) / f"{video_item['videoid']}.npy"
if latent_save_path.exists():
return None, None, None, None, False
video_reader = self.request_ceph_data(video_item["video_path"])
fps = video_reader.get_avg_fps()
stride = 1
if self.use_stride:
if int(fps) >= 50:
stride = 2
else:
stride = 1
else:
stride = 1
video_length = len(video_reader)
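        # The VAE expects frame counts of the form time_compression_ratio * n + 1 (e.g. 4n+1),
        # so short videos are truncated so that the strided sample still satisfies this.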
if video_length < self.sample_n_frames*stride:
sample_n_frames = video_length - (video_length - 1) % (self.vae_time_compression_ratio*stride) # 4n+1/8n+1
else:
sample_n_frames = self.sample_n_frames*stride
start_idx = 0
batch_index = list(range(start_idx, start_idx + sample_n_frames, stride))
# 20250322 pftq: fixed to return 5 values for consistency and "not enough values to unpack" error
if len(batch_index) == 0:
print(f"get video len=0, skip for {video_item['video_path']}")
return None, video_item["videoid"], video_item["video_path"], video_item["prompt"], False
# Read frames
try:
video_images = video_reader.get_batch(batch_index).asnumpy()
except Exception as e:
print(f'Error: {e}, video_path: {video_item["video_path"]}')
raise
pixel_values = torch.from_numpy(video_images).permute(0, 3, 1, 2).contiguous()
del video_reader
return pixel_values, video_item["videoid"], video_item["video_path"], video_item["prompt"], True
def preprocess_train(self, frames):
height, width = frames.shape[-2:]
# Resize & Crop
if self.enable_multi_aspect_ratio:
bw, bh = self.get_closest_ratio(width=width, height=height, ratios=self.aspect_ratios, buckets=self.buckets)
sample_size = bh, bw
target_size = self.get_target_size(frames, sample_size)
train_crop = transforms.CenterCrop(sample_size) if self.is_center_crop else transforms.RandomCrop(sample_size)
else:
sample_size = self.sample_size
target_size = self.get_target_size(frames, sample_size)
train_crop = self.train_crop
frames = transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)(frames)
        if self.is_center_crop:
            # CenterCrop computes its own offsets; no manual y1/x1 calculation is needed here.
            frames = train_crop(frames)
        else:
            y1, x1, h, w = train_crop.get_params(frames, sample_size)
            frames = crop(frames, y1, x1, h, w)
return frames
@staticmethod
def get_closest_ratio(width: float, height: float, ratios: list, buckets: list):
aspect_ratio = float(width) / float(height)
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
return buckets[closest_ratio_id]
@staticmethod
def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def get_target_size(self, frames, target_size):
T, C, H, W = frames.shape
th, tw = target_size
r = max(th / H, tw / W)
target_size = int(H * r), int(W * r)
return target_size
def __len__(self):
return self.length
def __getitem__(self, idx):
try:
pixel, videoid, video_path, prompt, valid = self.get_item(idx)
if pixel is not None and valid:
pixel = self.preprocess_train(pixel)
            sample = dict(pixel_values=pixel, videoid=videoid, video_path=video_path, prompt=prompt, valid=valid)
return sample
except Exception as e:
print(e)
return dict(pixel_values=None, videoid=None, video_path=None, prompt=None, valid=False)
from typing import Tuple, List, Dict
import sys
from pathlib import Path
import argparse
import time
import os
import traceback
import random
import numpy as np
from einops import rearrange
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import VideoDataset, MultiBucketDataset, split_video_urls
import json
import glob
from omegaconf import OmegaConf
from hyvideo.vae import load_vae
DEVICE = "cuda"
DTYPE = torch.float16
def seed_everything(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
@torch.no_grad()
def extract(
vae: torch.nn.Module,
meta_files: List[str],
output_base_dir: str,
sample_n_frames: int,
target_size: Tuple[int, int],
enable_multi_aspect_ratio: bool = False,
use_stride: bool = False,
batch_size=None,
):
dataset = VideoDataset(
meta_files=meta_files,
latent_cache_dir=output_base_dir,
sample_size=target_size,
sample_n_frames=sample_n_frames,
is_center_crop=True,
enable_multi_aspect_ratio=enable_multi_aspect_ratio,
vae_time_compression_ratio=vae.time_compression_ratio,
use_stride=use_stride
)
if batch_size is not None:
dataset = MultiBucketDataset(dataset, batch_size=batch_size)
dataloader = DataLoader(
dataset,
batch_size=None,
collate_fn=dataset.collate_fn if batch_size is not None else None,
shuffle=False,
num_workers=8,
prefetch_factor=4,
pin_memory=False,
)
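    # batch_size=None disables the DataLoader's own batching: each yielded item is either a single
    # VideoDataset sample dict (the default path) or a pre-bucketed list of samples from
    # MultiBucketDataset that collate_fn stacks into one batch.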
normalize_fn = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
save_json_path = Path(output_base_dir) / "json_path"
if not os.path.exists(save_json_path):
os.makedirs(save_json_path, exist_ok=True)
for i, item in enumerate(dataloader):
print(f"processing video latent extraction {i}")
if batch_size is None:
if item.get("valid", True) is False:
continue
            item["videoid"] = [item["videoid"]]
            item["valid"] = [item["valid"]]
            item["prompt"] = [item["prompt"]]
            # Wrap video_path as well so the per-sample indexing below also works in single-sample mode.
            item["video_path"] = [item["video_path"]]
try:
pixel_values = item["pixel_values"]
pixel_values = pixel_values.to(device=vae.device, dtype=vae.dtype)
pixel_values = pixel_values / 255.
pixel_values = normalize_fn(pixel_values)
if pixel_values.ndim == 4:
pixel_values = pixel_values.unsqueeze(0)
pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
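            # With the 884-16c VAE (8x spatial and 4x temporal compression, 16 latent channels),
            # z below has shape [B, 16, (F - 1) // 4 + 1, H // 8, W // 8].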
z = vae.encode(pixel_values).latent_dist.mode()
z = z.detach().to(DTYPE).cpu().numpy()
assert z.shape[0] == len(item["videoid"])
for k in range(z.shape[0]):
save_path = Path(output_base_dir) / f"{item['videoid'][k]}.npy"
np.save(save_path, z[k][None, ...])
data = {"video_id": item["videoid"][k],
"latent_shape": z[k][None,...].shape,
"video_path": item["video_path"][k],
"prompt": item["prompt"][k],
"npy_save_path": str(save_path)}
with open(save_json_path / f"{item['videoid'][k]}.json", "w", encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False)
except Exception as e:
traceback.print_exc()
def main(
local_rank: int,
vae_path: str,
meta_files: str,
output_base_dir: str,
sample_n_frames: int,
target_size: Tuple[int, int],
enable_multi_aspect_ratio: bool = False,
use_stride: bool = False,
seed: int = 42,
):
seed_everything(seed)
global_rank = local_rank
world_size = int(os.environ["HOST_GPU_NUM"])
print(f"split video urls")
start, end, meta_files = split_video_urls(meta_files, global_rank, world_size)
print(f"Load VAE")
vae, vae_path, spatial_compression_ratio, time_compression_ratio = load_vae(
vae_type="884-16c-hy",
vae_precision='fp16',
vae_path=vae_path,
device=DEVICE,
)
# vae.enable_temporal_tiling()
vae.enable_spatial_tiling()
vae.eval()
print(f"processing video latent extraction")
extract(vae, meta_files, output_base_dir, sample_n_frames, target_size, enable_multi_aspect_ratio, use_stride)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, required=True)
parser.add_argument("--config", default='./vae.yaml', type=str)
args = parser.parse_args()
config = OmegaConf.load(args.config)
vae_path = config.vae_path
sample_n_frames = config.sample_n_frames
target_size = [config.target_size[0], config.target_size[1]]
enable_multi_aspect_ratio = config.enable_multi_aspect_ratio
output_base_dir = config.output_base_dir
use_stride = config.use_stride
meta_files = config.video_url_files
main(args.local_rank, vae_path, meta_files, output_base_dir, sample_n_frames, target_size, enable_multi_aspect_ratio, use_stride)
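# start.sh: launch one extraction worker per GPU.
# Assumes HOST_GPU_NUM is exported (see the README) and that the script is run from the repository root.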
export PYTHONPATH=${PYTHONPATH}:`pwd`
for ((i=0;i<$HOST_GPU_NUM;++i)); do
CUDA_VISIBLE_DEVICES=$i python3 -u hyvideo/hyvae_extract/run.py --local_rank $i --config 'hyvideo/hyvae_extract/vae.yaml'&
done
# CUDA_VISIBLE_DEVICES=0 python3 -u hyvideo/hyvae_extract/run.py --local_rank 0 --config 'hyvideo/hyvae_extract/vae.yaml'&
wait
echo "Finished."
vae_path: "./ckpts/hunyuan-video-i2v-720p/vae"
video_url_files: "/path/to/meta_file.list"
output_base_dir: "/path/to/output/directory"
sample_n_frames: 129
target_size:
- 480
- 480
enable_multi_aspect_ratio: True
use_stride: True
import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path
from loguru import logger
import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.utils.lora_utils import load_lora_for_pipeline
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
from hyvideo.modules.fp8_optimization import convert_fp8_linear
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from safetensors.torch import load_file
try:
import xfuser
from xfuser.core.distributed import (
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
initialize_model_parallel,
init_distributed_environment
)
except ImportError:
xfuser = None
get_sequence_parallel_world_size = None
get_sequence_parallel_rank = None
get_sp_group = None
initialize_model_parallel = None
init_distributed_environment = None
###############################################
# 20250308 pftq: Riflex workaround to fix 192-frame-limit bug, credit to Kijai for finding it in ComfyUI and thu-ml for making it
# https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
from diffusers.models.embeddings import get_1d_rotary_pos_embed
import numpy as np
from typing import Union,Optional
def get_1d_rotary_pos_embed_riflex(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
k: Optional[int] = None,
L_test: Optional[int] = None,
):
"""
RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
L_test (`int`, *optional*, defaults to None): the number of frames for inference
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
) # [D/2]
# === Riflex modification start ===
# Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
# Empirical observations show that a few videos may exhibit repetition in the tail frames.
# To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
if k is not None:
freqs[k-1] = 0.9 * 2 * torch.pi / L_test
# === Riflex modification end ===
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
###############################################
def parallelize_transformer(pipe):
transformer = pipe.transformer
original_forward = transformer.forward
@functools.wraps(transformer.__class__.forward)
def new_forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
text_states: torch.Tensor = None,
text_mask: torch.Tensor = None, # Now we don't use it.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
):
if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
# try to split x by height
split_dim = -2
elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
# try to split x by width
split_dim = -1
else:
raise ValueError(f"Cannot split video sequence into ulysses_degree x ring_degree ({get_sequence_parallel_world_size()}) parts evenly")
# patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2
x = torch.chunk(x, get_sequence_parallel_world_size(),dim=split_dim)[get_sequence_parallel_rank()]
dim_thw = freqs_cos.shape[-1]
freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
freqs_cos = torch.chunk(freqs_cos, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
freqs_cos = freqs_cos.reshape(-1, dim_thw)
dim_thw = freqs_sin.shape[-1]
freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
freqs_sin = torch.chunk(freqs_sin, get_sequence_parallel_world_size(),dim=split_dim - 1)[get_sequence_parallel_rank()]
freqs_sin = freqs_sin.reshape(-1, dim_thw)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
for block in transformer.double_blocks + transformer.single_blocks:
block.hybrid_seq_parallel_attn = xFuserLongContextAttention()
output = original_forward(
x,
t,
text_states,
text_mask,
text_states_2,
freqs_cos,
freqs_sin,
guidance,
return_dict,
)
return_dict = not isinstance(output, tuple)
sample = output["x"]
sample = get_sp_group().all_gather(sample, dim=split_dim)
output["x"] = sample
return output
new_forward = new_forward.__get__(transformer)
transformer.forward = new_forward
class Inference(object):
def __init__(
self,
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=None,
pipeline=None,
use_cpu_offload=False,
device=None,
logger=None,
parallel_args=None,
):
self.vae = vae
self.vae_kwargs = vae_kwargs
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
self.model = model
self.pipeline = pipeline
self.use_cpu_offload = use_cpu_offload
self.args = args
self.device = (
device
if device is not None
else "cuda"
if torch.cuda.is_available()
else "cpu"
)
self.logger = logger
self.parallel_args = parallel_args
    # 20250316 pftq: Fixed multi-GPU loading times of up to 20 min caused by loading contention, by loading models on one GPU and broadcasting to the rest.
@classmethod
def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
"""
Initialize the Inference pipeline.
Args:
pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
args (argparse.Namespace): The arguments for the pipeline.
device (int): The device for inference. Default is None.
"""
logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
# ========================================================================
# Initialize Distributed Environment
# ========================================================================
# 20250316 pftq: Modified to extract rank and world_size early for sequential loading
if args.ulysses_degree > 1 or args.ring_degree > 1:
assert xfuser is not None, "Ulysses Attention and Ring Attention requires xfuser package."
assert args.use_cpu_offload is False, "Cannot enable use_cpu_offload in the distributed environment."
# 20250316 pftq: Set local rank and device explicitly for NCCL
local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank) # 20250316 pftq: Set CUDA device explicitly
dist.init_process_group("nccl") # 20250316 pftq: Removed device_id, rely on set_device
rank = dist.get_rank()
world_size = dist.get_world_size()
assert world_size == args.ring_degree * args.ulysses_degree, \
"number of GPUs should be equal to ring_degree * ulysses_degree."
init_distributed_environment(rank=rank, world_size=world_size)
initialize_model_parallel(
sequence_parallel_degree=world_size,
ring_degree=args.ring_degree,
ulysses_degree=args.ulysses_degree,
)
else:
rank = 0 # 20250316 pftq: Default rank for single GPU
world_size = 1 # 20250316 pftq: Default world_size for single GPU
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}
torch.set_grad_enabled(False)
# ========================================================================
# Build main model, VAE, and text encoder sequentially on rank 0
# ========================================================================
# 20250316 pftq: Load models only on rank 0, then broadcast
if rank == 0:
logger.info("Building model...")
factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
if args.i2v_mode and args.i2v_condition_type == "latent_concat":
in_channels = args.latent_channels * 2 + 1
image_embed_interleave = 2
elif args.i2v_mode and args.i2v_condition_type == "token_replace":
in_channels = args.latent_channels
image_embed_interleave = 4
else:
in_channels = args.latent_channels
image_embed_interleave = 1
out_channels = args.latent_channels
if args.embedded_cfg_scale:
factor_kwargs["guidance_embed"] = True
model = load_model(
args,
in_channels=in_channels,
out_channels=out_channels,
factor_kwargs=factor_kwargs,
)
if args.use_fp8:
convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
model = model.to(device)
model = Inference.load_state_dict(args, model, pretrained_model_path)
model.eval()
# VAE
vae, _, s_ratio, t_ratio = load_vae(
args.vae,
args.vae_precision,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
# Text encoder
if args.i2v_mode:
args.text_encoder = "llm-i2v"
args.tokenizer = "llm-i2v"
args.prompt_template = "dit-llm-encode-i2v"
args.prompt_template_video = "dit-llm-encode-video-i2v"
if args.prompt_template_video is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
elif args.prompt_template is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
else:
crop_start = 0
max_length = args.text_len + crop_start
prompt_template = PROMPT_TEMPLATE[args.prompt_template] if args.prompt_template is not None else None
prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None
text_encoder = TextEncoder(
text_encoder_type=args.text_encoder,
max_length=max_length,
text_encoder_precision=args.text_encoder_precision,
tokenizer_type=args.tokenizer,
i2v_mode=args.i2v_mode,
prompt_template=prompt_template,
prompt_template_video=prompt_template_video,
hidden_state_skip_layer=args.hidden_state_skip_layer,
apply_final_norm=args.apply_final_norm,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
image_embed_interleave=image_embed_interleave
)
text_encoder_2 = None
if args.text_encoder_2 is not None:
text_encoder_2 = TextEncoder(
text_encoder_type=args.text_encoder_2,
max_length=args.text_len_2,
text_encoder_precision=args.text_encoder_precision_2,
tokenizer_type=args.tokenizer_2,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
else:
# 20250316 pftq: Initialize as None on non-zero ranks
model = None
vae = None
vae_kwargs = None
text_encoder = None
text_encoder_2 = None
# 20250316 pftq: Broadcast models to all ranks
if world_size > 1:
logger.info(f"Rank {rank}: Starting broadcast synchronization")
dist.barrier() # Ensure rank 0 finishes loading before broadcasting
if rank != 0:
# Reconstruct model skeleton on non-zero ranks
factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
if args.i2v_mode and args.i2v_condition_type == "latent_concat":
in_channels = args.latent_channels * 2 + 1
image_embed_interleave = 2
elif args.i2v_mode and args.i2v_condition_type == "token_replace":
in_channels = args.latent_channels
image_embed_interleave = 4
else:
in_channels = args.latent_channels
image_embed_interleave = 1
out_channels = args.latent_channels
if args.embedded_cfg_scale:
factor_kwargs["guidance_embed"] = True
model = load_model(args, in_channels=in_channels, out_channels=out_channels, factor_kwargs=factor_kwargs).to(device)
vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device=device if not args.use_cpu_offload else "cpu")
vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
vae = vae.to(device)
if args.i2v_mode:
args.text_encoder = "llm-i2v"
args.tokenizer = "llm-i2v"
args.prompt_template = "dit-llm-encode-i2v"
args.prompt_template_video = "dit-llm-encode-video-i2v"
if args.prompt_template_video is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
elif args.prompt_template is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
else:
crop_start = 0
max_length = args.text_len + crop_start
prompt_template = PROMPT_TEMPLATE[args.prompt_template] if args.prompt_template is not None else None
prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None
text_encoder = TextEncoder(
text_encoder_type=args.text_encoder,
max_length=max_length,
text_encoder_precision=args.text_encoder_precision,
tokenizer_type=args.tokenizer,
i2v_mode=args.i2v_mode,
prompt_template=prompt_template,
prompt_template_video=prompt_template_video,
hidden_state_skip_layer=args.hidden_state_skip_layer,
apply_final_norm=args.apply_final_norm,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
image_embed_interleave=image_embed_interleave
).to(device)
text_encoder_2 = None
if args.text_encoder_2 is not None:
text_encoder_2 = TextEncoder(
text_encoder_type=args.text_encoder_2,
max_length=args.text_len_2,
text_encoder_precision=args.text_encoder_precision_2,
tokenizer_type=args.tokenizer_2,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
).to(device)
# Broadcast model parameters with logging
logger.info(f"Rank {rank}: Broadcasting model parameters")
for param in model.parameters():
dist.broadcast(param.data, src=0)
model.eval()
logger.info(f"Rank {rank}: Broadcasting VAE parameters")
for param in vae.parameters():
dist.broadcast(param.data, src=0)
# 20250316 pftq: Use broadcast_object_list for vae_kwargs
logger.info(f"Rank {rank}: Broadcasting vae_kwargs")
vae_kwargs_list = [vae_kwargs] if rank == 0 else [None]
dist.broadcast_object_list(vae_kwargs_list, src=0)
vae_kwargs = vae_kwargs_list[0]
logger.info(f"Rank {rank}: Broadcasting text_encoder parameters")
for param in text_encoder.parameters():
dist.broadcast(param.data, src=0)
if text_encoder_2 is not None:
logger.info(f"Rank {rank}: Broadcasting text_encoder_2 parameters")
for param in text_encoder_2.parameters():
dist.broadcast(param.data, src=0)
return cls(
args=args,
vae=vae,
vae_kwargs=vae_kwargs,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
model=model,
use_cpu_offload=args.use_cpu_offload,
device=device,
logger=logger,
parallel_args=parallel_args
)
@staticmethod
def load_state_dict(args, model, pretrained_model_path):
load_key = args.load_key
if args.i2v_mode:
dit_weight = Path(args.i2v_dit_weight)
else:
dit_weight = Path(args.dit_weight)
if dit_weight is None:
model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
files = list(model_dir.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {model_dir}")
            if files[0].name.startswith("pytorch_model_"):
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(f"Multiple model weights found in {dit_weight}, using {model_path}")
bare_model = False
else:
raise ValueError(f"Invalid model path: {dit_weight} with unrecognized weight format")
else:
if dit_weight.is_dir():
files = list(dit_weight.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {dit_weight}")
                if files[0].name.startswith("pytorch_model_"):
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(f"Multiple model weights found in {dit_weight}, using {model_path}")
bare_model = False
else:
raise ValueError(f"Invalid model path: {dit_weight} with unrecognized weight format")
elif dit_weight.is_file():
model_path = dit_weight
bare_model = "unknown"
else:
raise ValueError(f"Invalid model path: {dit_weight}")
if not model_path.exists():
            raise ValueError(f"model_path does not exist: {model_path}")
logger.info(f"Loading torch model {model_path}...")
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
bare_model = False
if bare_model is False:
if load_key in state_dict:
state_dict = state_dict[load_key]
else:
raise KeyError(f"Missing key: `{load_key}` in the checkpoint: {model_path}")
model.load_state_dict(state_dict, strict=True)
return model
@staticmethod
def parse_size(size):
if isinstance(size, int):
size = [size]
if not isinstance(size, (list, tuple)):
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
if len(size) == 1:
size = [size[0], size[0]]
if len(size) != 2:
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
return size
class HunyuanVideoSampler(Inference):
def __init__(
self,
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=None,
pipeline=None,
use_cpu_offload=False,
device=0,
logger=None,
parallel_args=None
):
super().__init__(
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=text_encoder_2,
pipeline=pipeline,
use_cpu_offload=use_cpu_offload,
device=device,
logger=logger,
parallel_args=parallel_args
)
self.pipeline = self.load_diffusion_pipeline(
args=args,
vae=self.vae,
text_encoder=self.text_encoder,
text_encoder_2=self.text_encoder_2,
model=self.model,
device=self.device,
)
if args.i2v_mode:
self.default_negative_prompt = NEGATIVE_PROMPT_I2V
if args.use_lora:
self.pipeline = load_lora_for_pipeline(
self.pipeline, args.lora_path, LORA_PREFIX_TRANSFORMER="Hunyuan_video_I2V_lora", alpha=args.lora_scale,
device=self.device,
is_parallel=(self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1))
logger.info(f"load lora {args.lora_path} into pipeline, lora scale is {args.lora_scale}.")
else:
self.default_negative_prompt = NEGATIVE_PROMPT
if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
parallelize_transformer(self.pipeline)
def load_diffusion_pipeline(
self,
args,
vae,
text_encoder,
text_encoder_2,
model,
scheduler=None,
device=None,
progress_bar_config=None,
):
if scheduler is None:
if args.denoise_type == "flow":
scheduler = FlowMatchDiscreteScheduler(
shift=args.flow_shift,
reverse=args.flow_reverse,
solver=args.flow_solver,
)
else:
raise ValueError(f"Invalid denoise type {args.denoise_type}")
pipeline = HunyuanVideoPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
transformer=model,
scheduler=scheduler,
progress_bar_config=progress_bar_config,
args=args,
)
if self.use_cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to(device)
return pipeline
# 20250317 pftq: Modified to use Riflex when >192 frames
def get_rotary_pos_embed(self, video_length, height, width):
target_ndim = 3
ndim = 5 - 2 # B, C, F, H, W -> F, H, W
# Compute latent sizes based on VAE type
if "884" in self.args.vae:
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
elif "888" in self.args.vae:
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
else:
latents_size = [video_length, height // 8, width // 8]
# Compute rope sizes
if isinstance(self.model.patch_size, int):
assert all(s % self.model.patch_size == 0 for s in latents_size), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [s // self.model.patch_size for s in latents_size]
elif isinstance(self.model.patch_size, list):
assert all(
s % self.model.patch_size[idx] == 0
for idx, s in enumerate(latents_size)
), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # Pad time axis
# 20250316 pftq: Add RIFLEx logic for > 192 frames
L_test = rope_sizes[0] # Latent frames
L_train = 25 # Training length from HunyuanVideo
actual_num_frames = video_length # Use input video_length directly
head_dim = self.model.hidden_size // self.model.heads_num
rope_dim_list = self.model.rope_dim_list or [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal head_dim"
if actual_num_frames > 192:
k = 2+((actual_num_frames + 3) // (4 * L_train))
k = max(4, min(8, k))
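            # Heuristic choice of the intrinsic-frequency index k from the frame count, clamped to [4, 8];
            # it is passed to get_1d_rotary_pos_embed_riflex for the temporal axis only.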
logger.debug(f"actual_num_frames = {actual_num_frames} > 192, RIFLEx applied with k = {k}")
# Compute positional grids for RIFLEx
axes_grids = [torch.arange(size, device=self.device, dtype=torch.float32) for size in rope_sizes]
grid = torch.meshgrid(*axes_grids, indexing="ij")
grid = torch.stack(grid, dim=0) # [3, t, h, w]
pos = grid.reshape(3, -1).t() # [t * h * w, 3]
# Apply RIFLEx to temporal dimension
freqs = []
for i in range(3):
if i == 0: # Temporal with RIFLEx
freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(
rope_dim_list[i],
pos[:, i],
theta=self.args.rope_theta,
use_real=True,
k=k,
L_test=L_test
)
else: # Spatial with default RoPE
freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(
rope_dim_list[i],
pos[:, i],
theta=self.args.rope_theta,
use_real=True,
k=None,
L_test=None
)
freqs.append((freqs_cos, freqs_sin))
logger.debug(f"freq[{i}] shape: {freqs_cos.shape}, device: {freqs_cos.device}")
freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
logger.debug(f"freqs_cos shape: {freqs_cos.shape}, device: {freqs_cos.device}")
else:
# 20250316 pftq: Original code for <= 192 frames
logger.debug(f"actual_num_frames = {actual_num_frames} <= 192, using original RoPE")
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=self.args.rope_theta,
use_real=True,
theta_rescale_factor=1,
)
logger.debug(f"freqs_cos shape: {freqs_cos.shape}, device: {freqs_cos.device}")
return freqs_cos, freqs_sin
@torch.no_grad()
def predict(
self,
prompt,
height=192,
width=336,
video_length=129,
seed=None,
negative_prompt=None,
infer_steps=50,
guidance_scale=6.0,
flow_shift=5.0,
embedded_guidance_scale=None,
batch_size=1,
num_videos_per_prompt=1,
i2v_mode=False,
i2v_resolution="720p",
i2v_image_path=None,
i2v_condition_type=None,
i2v_stability=True,
ulysses_degree=1,
ring_degree=1,
**kwargs,
):
out_dict = dict()
if isinstance(seed, torch.Tensor):
seed = seed.tolist()
if seed is None:
seeds = [
random.randint(0, 1_000_000)
for _ in range(batch_size * num_videos_per_prompt)
]
elif isinstance(seed, int):
seeds = [
seed + i
for _ in range(batch_size)
for i in range(num_videos_per_prompt)
]
elif isinstance(seed, (list, tuple)):
if len(seed) == batch_size:
seeds = [
int(seed[i]) + j
for i in range(batch_size)
for j in range(num_videos_per_prompt)
]
elif len(seed) == batch_size * num_videos_per_prompt:
seeds = [int(s) for s in seed]
else:
raise ValueError(
f"Length of seed must be equal to number of prompt(batch_size) or "
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
)
else:
raise ValueError(
f"Seed must be an integer, a list of integers, or None, got {seed}."
)
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
out_dict["seeds"] = seeds
if width <= 0 or height <= 0 or video_length <= 0:
raise ValueError(
f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
)
if (video_length - 1) % 4 != 0:
raise ValueError(
f"`video_length-1` must be a multiple of 4, got {video_length}"
)
logger.info(
f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
)
target_height = align_to(height, 16)
target_width = align_to(width, 16)
target_video_length = video_length
out_dict["size"] = (target_height, target_width, target_video_length)
if not isinstance(prompt, str):
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
prompt = [prompt.strip()]
if negative_prompt is None or negative_prompt == "":
negative_prompt = self.default_negative_prompt
if guidance_scale == 1.0:
negative_prompt = ""
if not isinstance(negative_prompt, str):
raise TypeError(
f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
)
negative_prompt = [negative_prompt.strip()]
scheduler = FlowMatchDiscreteScheduler(
shift=flow_shift,
reverse=self.args.flow_reverse,
solver=self.args.flow_solver
)
self.pipeline.scheduler = scheduler
img_latents = None
semantic_images = None
if i2v_mode:
if i2v_resolution == "720p":
bucket_hw_base_size = 960
elif i2v_resolution == "540p":
bucket_hw_base_size = 720
elif i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
semantic_images = [Image.open(i2v_image_path).convert('RGB')]
origin_size = semantic_images[0].size
crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
if ulysses_degree != 1 or ring_degree != 1:
diviser = get_sequence_parallel_world_size() * 8 * 2
if closest_size[0] % diviser != 0 and closest_size[1] % diviser != 0:
xdit_crop_size_list = list(filter(lambda x: x[0] % diviser == 0 or x[1] % diviser == 0, crop_size_list))
xdit_aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in xdit_crop_size_list])
xdit_closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], xdit_aspect_ratios, xdit_crop_size_list)
assert os.getenv("ALLOW_RESIZE_FOR_SP") is not None, \
f"The image resolution is {origin_size}. " \
f"Based on the input i2v-resultion ({i2v_resolution}), " \
f"the closest ratio of resolution supported by HunyuanVideo-I2V is ({closest_size[1]}, {closest_size[0]}), " \
f"the latent resolution of which is ({closest_size[1] // 16}, {closest_size[0] // 16}). " \
f"You run the program with {get_sequence_parallel_world_size()} GPUs " \
f"(SP degree={get_sequence_parallel_world_size()}). " \
f"However, neither of the width ({closest_size[1] // 16}) or the " \
f"height ({closest_size[0] // 16}) " \
f"is divisible by the SP degree ({get_sequence_parallel_world_size()}). " \
f"Please set ALLOW_RESIZE_FOR_SP=1 in the environment to allow xDiT to resize the image to {xdit_closest_size}. " \
f"If you do not want to resize the image, please try other SP degrees and rerun the program. "
logger.debug(f"xDiT resizes the input image to {xdit_closest_size}.")
closest_size = xdit_closest_size
resize_param = min(closest_size)
center_crop_param = closest_size
ref_image_transform = transforms.Compose([
transforms.Resize(resize_param),
transforms.CenterCrop(center_crop_param),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode()
img_latents.mul_(self.pipeline.vae.config.scaling_factor)
target_height, target_width = closest_size
freqs_cos, freqs_sin = self.get_rotary_pos_embed(
target_video_length, target_height, target_width
)
n_tokens = freqs_cos.shape[0]
debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {target_video_length}
prompt: {prompt}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
num_videos_per_prompt: {num_videos_per_prompt}
guidance_scale: {guidance_scale}
n_tokens: {n_tokens}
flow_shift: {flow_shift}
embedded_guidance_scale: {embedded_guidance_scale}
i2v_stability: {i2v_stability}"""
if ulysses_degree != 1 or ring_degree != 1:
debug_str += f"""
ulysses_degree: {ulysses_degree}
ring_degree: {ring_degree}"""
logger.debug(debug_str)
start_time = time.time()
samples = self.pipeline(
prompt=prompt,
height=target_height,
width=target_width,
video_length=target_video_length,
num_inference_steps=infer_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
generator=generator,
output_type="pil",
freqs_cis=(freqs_cos, freqs_sin),
n_tokens=n_tokens,
embedded_guidance_scale=embedded_guidance_scale,
data_type="video" if target_video_length > 1 else "image",
is_progress_bar=True,
vae_ver=self.args.vae,
enable_tiling=self.args.vae_tiling,
i2v_mode=i2v_mode,
i2v_condition_type=i2v_condition_type,
i2v_stability=i2v_stability,
img_latents=img_latents,
semantic_images=semantic_images,
)[0]
out_dict["samples"] = samples
out_dict["prompts"] = prompt
gen_time = time.time() - start_time
logger.info(f"Success, time: {gen_time}")
return out_dict
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
def load_model(args, in_channels, out_channels, factor_kwargs):
"""load hunyuan video model
Args:
args (dict): model args
in_channels (int): input channels number
out_channels (int): output channels number
factor_kwargs (dict): factor kwargs
Returns:
model (nn.Module): The hunyuan video model
"""
if args.model in HUNYUAN_VIDEO_CONFIG.keys():
model = HYVideoDiffusionTransformer(
args,
in_channels=in_channels,
out_channels=out_channels,
**HUNYUAN_VIDEO_CONFIG[args.model],
**factor_kwargs,
)
return model
else:
raise NotImplementedError()
import torch.nn as nn
def get_activation_layer(act_type):
"""get activation layer
Args:
act_type (str): the activation type
Returns:
torch.nn.functional: the activation layer
"""
if act_type == "gelu":
return lambda: nn.GELU()
elif act_type == "gelu_tanh":
# Approximate `tanh` requires torch >= 1.13
return lambda: nn.GELU(approximate="tanh")
elif act_type == "relu":
return nn.ReLU
elif act_type == "silu":
return nn.SiLU
else:
raise ValueError(f"Unknown activation type: {act_type}")
import importlib.metadata
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn = None
flash_attn_varlen_func = None
_flash_attn_forward = None
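# Pre-/post-processing layouts per attention backend:
#   "flash": flatten (batch, seq) into one varlen dimension before flash_attn_varlen_func; no-op afterwards.
#   "torch"/"vanilla": transpose [b, s, a, d] -> [b, a, s, d] for scaled_dot_product_attention or the
#   manual implementation, then transpose back.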
MEMORY_LAYOUT = {
"flash": (
lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
lambda x: x,
),
"torch": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
"vanilla": (
lambda x: x.transpose(1, 2),
lambda x: x.transpose(1, 2),
),
}
def get_cu_seqlens(text_mask, img_len):
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
Args:
text_mask (torch.Tensor): the mask of text
img_len (int): the length of image
Returns:
torch.Tensor: the calculated cu_seqlens for flash attention
"""
batch_size = text_mask.shape[0]
text_len = text_mask.sum(dim=1)
max_len = text_mask.shape[1] + img_len
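    # Each packed sequence occupies a fixed max_len slot: the valid image + text tokens first, then padding.
    # For every sample, two offsets are recorded: the end of the valid span and the end of its padded slot.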
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
for i in range(batch_size):
s = text_len[i] + img_len
s1 = i * max_len + s
s2 = (i + 1) * max_len
cu_seqlens[2 * i + 1] = s1
cu_seqlens[2 * i + 2] = s2
return cu_seqlens
def attention(
q,
k,
v,
mode="flash",
drop_rate=0,
attn_mask=None,
causal=False,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
batch_size=1,
):
"""
Perform QKV self attention.
Args:
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
drop_rate (float): Dropout rate in attention map. (default: 0)
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
(default: None)
causal (bool): Whether to use causal attention. (default: False)
cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into q.
cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
used to index into kv.
max_seqlen_q (int): The maximum sequence length in the batch of q.
max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
Returns:
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
"""
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
q = pre_attn_layout(q)
k = pre_attn_layout(k)
v = pre_attn_layout(v)
if mode == "torch":
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
)
elif mode == "flash":
x = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
)
# x with shape [(bxs), a, d]
x = x.view(
batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
) # reshape x to [b, s, a, d]
elif mode == "vanilla":
scale_factor = 1 / math.sqrt(q.size(-1))
b, a, s, _ = q.shape
s1 = k.size(2)
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
if causal:
# Only applied to self attention
assert (
attn_mask is None
), "Causal mask and attn_mask cannot be used together"
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
diagonal=0
)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)  # .to() is not in-place; keep the cast
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
# TODO: Maybe force q and k to be float32 to avoid numerical overflow
attn = (q @ k.transpose(-2, -1)) * scale_factor
attn += attn_bias
attn = attn.softmax(dim=-1)
attn = torch.dropout(attn, p=drop_rate, train=True)
x = attn @ v
else:
raise NotImplementedError(f"Unsupported attention mode: {mode}")
x = post_attn_layout(x)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
def parallel_attention(
hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len,
img_kv_len,
cu_seqlens_q,
cu_seqlens_kv
):
attn1 = hybrid_seq_parallel_attn(
None,
q[:, :img_q_len, :, :],
k[:, :img_kv_len, :, :],
v[:, :img_kv_len, :, :],
dropout_p=0.0,
causal=False,
joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
joint_strategy="rear",
)
    # Compare version components numerically; plain string comparison breaks for e.g. "2.10.0".
    if tuple(int(p) for p in flash_attn.__version__.split(".")[:2]) >= (2, 7):
attn2, *_ = _flash_attn_forward(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
v[:,cu_seqlens_kv[1]:],
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
else:
attn2, *_ = _flash_attn_forward(
q[:,cu_seqlens_q[1]:],
k[:,cu_seqlens_kv[1]:],
v[:,cu_seqlens_kv[1]:],
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn = torch.cat([attn1, attn2], dim=1)
b, s, a, d = attn.shape
attn = attn.reshape(b, s, -1)
return attn
import math
import torch
import torch.nn as nn
from einops import rearrange, repeat
from ..utils.helpers import to_2tuple
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
Remove the _assert function in forward function to be compatible with multi-resolution images.
"""
def __init__(
self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv3d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
**factory_kwargs
)
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
if bias:
nn.init.zeros_(self.proj.bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
x = self.norm(x)
return x
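# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# Shows how the Conv3d-based PatchEmbed turns a video latent [B, C, T, H, W]
# into a token sequence [B, (T/pt)*(H/ph)*(W/pw), embed_dim]. It assumes
# to_2tuple passes an existing 3-tuple through unchanged (as in timm); the
# small sizes are hypothetical and the demo is never called at import time.
def _demo_patch_embed():
    embed = PatchEmbed(patch_size=(1, 2, 2), in_chans=16, embed_dim=64)
    x = torch.randn(1, 16, 4, 8, 8)            # [B, C, T, H, W]
    tokens = embed(x)
    assert tokens.shape == (1, 4 * 4 * 4, 64)  # (4/1) * (8/2) * (8/2) tokens
    return tokens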
class TextProjection(nn.Module):
"""
    Projects text embeddings into the transformer hidden size.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.linear_1 = nn.Linear(
in_features=in_channels,
out_features=hidden_size,
bias=True,
**factory_kwargs
)
self.act_1 = act_layer()
self.linear_2 = nn.Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=True,
**factory_kwargs
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int): the dimension of the output.
max_period (int): controls the minimum frequency of the embeddings.
Returns:
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
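# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# Sanity check of the sinusoidal embedding above: N scalar timesteps become an
# (N, dim) tensor whose first half holds cosines and second half sines. The
# demo function is hypothetical and never called at import time.
def _demo_timestep_embedding():
    t = torch.tensor([0.0, 500.0, 999.0])
    emb = timestep_embedding(t, dim=256)
    assert emb.shape == (3, 256)
    # t = 0 gives cos(0) = 1 in the first half and sin(0) = 0 in the second.
    assert torch.allclose(emb[0, :128], torch.ones(128))
    assert torch.allclose(emb[0, 128:], torch.zeros(128))
    return emb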
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(
self,
hidden_size,
act_layer,
frequency_embedding_size=256,
max_period=10000,
out_size=None,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
),
act_layer(),
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
)
nn.init.normal_(self.mlp[0].weight, std=0.02)
nn.init.normal_(self.mlp[2].weight, std=0.02)
def forward(self, t):
t_freq = timestep_embedding(
t, self.frequency_embedding_size, self.max_period
).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
_bits = torch.tensor(bits)
_mantissa_bit = torch.tensor(mantissa_bit)
_sign_bits = torch.tensor(sign_bits)
M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
E = _bits - _sign_bits - M
bias = 2 ** (E - 1) - 1
mantissa = 1
for i in range(mantissa_bit - 1):
mantissa += 1 / (2 ** (i+1))
maxval = mantissa * 2 ** (2**E - 1 - bias)
return maxval
def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
"""
Default is E4M3.
"""
bits = torch.tensor(bits)
mantissa_bit = torch.tensor(mantissa_bit)
sign_bits = torch.tensor(sign_bits)
M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
E = bits - sign_bits - M
bias = 2 ** (E - 1) - 1
mantissa = 1
for i in range(mantissa_bit - 1):
mantissa += 1 / (2 ** (i+1))
maxval = mantissa * 2 ** (2**E - 1 - bias)
    minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
input_clamp = torch.min(torch.max(x, minval), maxval)
log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
# dequant
qdq_out = torch.round(input_clamp / log_scales) * log_scales
return qdq_out, log_scales
def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
for i in range(len(x.shape) - 1):
scale = scale.unsqueeze(-1)
new_x = x / scale
quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
return quant_dequant_x, scale, log_scales
def fp8_activation_dequant(qdq_out, scale, dtype):
qdq_out = qdq_out.type(dtype)
quant_dequant_x = qdq_out * scale.to(dtype)
return quant_dequant_x
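# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# Round-trips a weight tensor through the simulated E4M3 quantization above:
# scale by the tensor-wise maximum, quantize/dequantize, then rescale. The
# reconstruction error stays small relative to the largest weight. Only the
# three helpers defined above are real; everything else is hypothetical.
def _demo_fp8_roundtrip():
    w = torch.randn(64, 64)
    scale = torch.abs(w).max() / get_fp_maxval()          # tensor-wise scale
    qdq, scale, _log_scales = fp8_tensor_quant(w, scale)  # simulated E4M3 values
    w_rec = fp8_activation_dequant(qdq, scale, torch.float32)
    rel_err = (w - w_rec).abs().max() / w.abs().max()
    assert rel_err < 0.1
    return rel_err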
def fp8_linear_forward(cls, original_dtype, input):
weight_dtype = cls.weight.dtype
#####
if cls.weight.dtype != torch.float8_e4m3fn:
maxval = get_fp_maxval()
scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
linear_weight = linear_weight.to(torch.float8_e4m3fn)
weight_dtype = linear_weight.dtype
else:
scale = cls.fp8_scale.to(cls.weight.device)
linear_weight = cls.weight
#####
if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
if True or len(input.shape) == 3:
cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
            if cls.bias is not None:
output = F.linear(input, cls_dequant, cls.bias)
else:
output = F.linear(input, cls_dequant)
return output
else:
return cls.original_forward(input.to(original_dtype))
else:
return cls.original_forward(input)
def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
setattr(module, "fp8_matmul_enabled", True)
# loading fp8 mapping file
fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
if os.path.exists(fp8_map_path):
fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
else:
raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
fp8_layers = []
for key, layer in module.named_modules():
if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
fp8_layers.append(key)
original_forward = layer.forward
layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
setattr(layer, "original_forward", original_forward)
setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
# Modified from timm library:
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
from functools import partial
import torch
import torch.nn as nn
from .modulate_layers import modulate
from ..utils.helpers import to_2tuple
class MLP(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_channels,
hidden_channels=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features or in_channels
hidden_channels = hidden_channels or in_channels
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(
in_channels, hidden_channels, bias=bias[0], **factory_kwargs
)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_channels, **factory_kwargs)
if norm_layer is not None
else nn.Identity()
)
self.fc2 = linear_layer(
hidden_channels, out_features, bias=bias[1], **factory_kwargs
)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
#
class MLPEmbedder(nn.Module):
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
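# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# Shape checks for the two feed-forward helpers above, assuming to_2tuple
# broadcasts scalar bias/drop values to a pair (as in timm). MLP keeps the
# token layout and maps back to out_features (here the default, the input
# width); MLPEmbedder maps a pooled vector into the hidden size. The demo is
# hypothetical and never called at import time.
def _demo_mlp_layers():
    mlp = MLP(in_channels=32, hidden_channels=128)
    x = torch.randn(2, 10, 32)                 # [B, L, C]
    assert mlp(x).shape == (2, 10, 32)
    embedder = MLPEmbedder(in_dim=768, hidden_dim=64)
    vec = torch.randn(2, 768)                  # e.g. a pooled text embedding
    assert embedder(vec).shape == (2, 64)
    return mlp, embedder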
class FinalLayer(nn.Module):
"""The final layer of DiT."""
def __init__(
self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# Just use LayerNorm for the final layer
self.norm_final = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
if isinstance(patch_size, int):
self.linear = nn.Linear(
hidden_size,
patch_size * patch_size * out_channels,
bias=True,
**factory_kwargs
)
else:
self.linear = nn.Linear(
hidden_size,
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
bias=True,
)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
# Here we don't distinguish between the modulate types. Just use the simple one.
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
x = self.linear(x)
return x
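# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# FinalLayer projects each token to pt*ph*pw*out_channels values, ready to be
# unpatchified back into a latent video. Because both the projection and its
# adaLN modulation are zero-initialized, a freshly built layer returns zeros.
# The sizes below are hypothetical.
def _demo_final_layer():
    layer = FinalLayer(hidden_size=64, patch_size=[1, 2, 2], out_channels=16,
                       act_layer=nn.SiLU)
    x = torch.randn(2, 10, 64)                 # [B, tokens, hidden]
    c = torch.randn(2, 64)                     # conditioning vector
    out = layer(x, c)
    assert out.shape == (2, 10, 1 * 2 * 2 * 16)
    assert torch.count_nonzero(out) == 0       # zero-init => all-zero output
    return out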
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
import torch.utils
import torch.utils.checkpoint
from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate, ckpt_wrapper
from .token_refiner import SingleTokenRefiner
class MMDoubleStreamBlock(nn.Module):
"""
    A multimodal DiT block with separate modulation for
    text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.img_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.txt_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: tuple = None,
condition_type: str = None,
token_replace_vec: torch.Tensor = None,
frist_frame_token_num: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if condition_type == "token_replace":
img_mod1, token_replace_img_mod1 = self.img_mod(vec, condition_type=condition_type, \
token_replace_vec=token_replace_vec)
(img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate) = img_mod1.chunk(6, dim=-1)
(tr_img_mod1_shift,
tr_img_mod1_scale,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate) = token_replace_img_mod1.chunk(6, dim=-1)
else:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
# Prepare image for attention.
img_modulated = self.img_norm1(img)
if condition_type == "token_replace":
img_modulated = modulate(
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale, condition_type=condition_type,
tr_shift=tr_img_mod1_shift, tr_scale=tr_img_mod1_scale,
frist_frame_token_num=frist_frame_token_num
)
else:
img_modulated = modulate(
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
assert (
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=img_k.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        # Calculate the img blocks.
if condition_type == "token_replace":
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate, condition_type=condition_type,
tr_gate=tr_img_mod1_gate, frist_frame_token_num=frist_frame_token_num)
img = img + apply_gate(
self.img_mlp(
modulate(
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale, condition_type=condition_type,
tr_shift=tr_img_mod2_shift, tr_scale=tr_img_mod2_scale, frist_frame_token_num=frist_frame_token_num
)
),
gate=img_mod2_gate, condition_type=condition_type,
tr_gate=tr_img_mod2_gate, frist_frame_token_num=frist_frame_token_num
)
else:
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(
modulate(
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
)
),
gate=img_mod2_gate,
)
        # Calculate the txt blocks.
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(
modulate(
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
)
),
gate=txt_mod2_gate,
)
return img, txt
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim ** -0.5
# qkv and mlp_in
self.linear1 = nn.Linear(
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
)
# proj and mlp_out
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
condition_type: str = None,
token_replace_vec: torch.Tensor = None,
frist_frame_token_num: int = None,
) -> torch.Tensor:
if condition_type == "token_replace":
mod, tr_mod = self.modulation(vec,
condition_type=condition_type,
token_replace_vec=token_replace_vec)
(mod_shift,
mod_scale,
mod_gate) = mod.chunk(3, dim=-1)
(tr_mod_shift,
tr_mod_scale,
tr_mod_gate) = tr_mod.chunk(3, dim=-1)
else:
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
if condition_type == "token_replace":
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale, condition_type=condition_type,
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, frist_frame_token_num=frist_frame_token_num)
else:
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# Apply RoPE if needed.
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
# Compute attention.
assert (
cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=x.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
# Compute activation in mlp stream, cat again and run second linear layer.
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
if condition_type == "token_replace":
output = x + apply_gate(output, gate=mod_gate, condition_type=condition_type,
tr_gate=tr_mod_gate, frist_frame_token_num=frist_frame_token_num)
return output
else:
return x + apply_gate(output, gate=mod_gate)
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
patch_size: list
The size of the patch.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
hidden_size: int
The hidden size of the transformer backbone.
heads_num: int
The number of attention heads.
mlp_width_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
mlp_act_type: str
The activation function of the MLP in the transformer block.
depth_double_blocks: int
The number of transformer blocks in the double blocks.
depth_single_blocks: int
The number of transformer blocks in the single blocks.
rope_dim_list: list
The dimension of the rotary embedding for t, h, w.
qkv_bias: bool
Whether to use bias in the qkv linear layer.
qk_norm: bool
Whether to use qk norm.
qk_norm_type: str
The type of qk norm.
guidance_embed: bool
Whether to use guidance embedding for distillation.
text_projection: str
The type of the text projection, default is single_refiner.
use_attention_mask: bool
Whether to use attention mask for text encoder.
dtype: torch.dtype
The dtype of the model.
device: torch.device
The device of the model.
"""
@register_to_config
def __init__(
self,
args: Any,
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
self.i2v_condition_type = args.i2v_condition_type
# Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
self.text_states_dim = args.text_states_dim
self.text_states_dim_2 = args.text_states_dim_2
# Gradient checkpoint.
self.gradient_checkpoint = args.gradient_checkpoint
self.gradient_checkpoint_layers = args.gradient_checkpoint_layers
if self.gradient_checkpoint:
assert self.gradient_checkpoint_layers <= mm_double_blocks_depth + mm_single_blocks_depth, \
f"Gradient checkpoint layers must be less or equal than the depth of the model. " \
f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and " \
f"depth={mm_double_blocks_depth + mm_single_blocks_depth}."
if hidden_size % heads_num != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
)
pe_dim = hidden_size // heads_num
if sum(rope_dim_list) != pe_dim:
raise ValueError(
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
)
self.hidden_size = hidden_size
self.heads_num = heads_num
# image projection
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
# time modulation
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance modulation
self.guidance_in = (
TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
if guidance_embed
else None
)
# double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
text_states: torch.Tensor = None,
text_mask: torch.Tensor = None, # Now we don't use it.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
out = {}
img = x
txt = text_states
_, _, ot, oh, ow = x.shape
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Prepare modulation vectors.
vec = self.time_in(t)
if self.i2v_condition_type == "token_replace":
token_replace_t = torch.zeros_like(t)
token_replace_vec = self.time_in(token_replace_t)
frist_frame_token_num = th * tw
else:
token_replace_vec = None
frist_frame_token_num = None
# text modulation
vec_2 = self.vector_in(text_states_2)
vec = vec + vec_2
if self.i2v_condition_type == "token_replace":
token_replace_vec = token_replace_vec + vec_2
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
vec = vec + self.guidance_in(guidance)
# Embed image and text.
img = self.img_in(img)
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
        # Compute cu_seqlens and max_seqlen for flash attention
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
cu_seqlens_kv = cu_seqlens_q
max_seqlen_q = img_seq_len + txt_seq_len
max_seqlen_kv = max_seqlen_q
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for layer_num, block in enumerate(self.double_blocks):
double_block_args = [
img,
txt,
vec,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
freqs_cis,
self.i2v_condition_type,
token_replace_vec,
frist_frame_token_num,
]
if self.training and self.gradient_checkpoint and \
(self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers):
# print(f'gradient checkpointing...')
img, txt = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *double_block_args, use_reentrant=False)
else:
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
if len(self.single_blocks) > 0:
            for layer_num, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
(freqs_cos, freqs_sin),
self.i2v_condition_type,
token_replace_vec,
frist_frame_token_num,
]
if self.training and self.gradient_checkpoint and \
(self.gradient_checkpoint_layers == -1 or layer_num + len(self.double_blocks) < self.gradient_checkpoint_layers):
x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
else:
x = block(*single_block_args)
img = x[:, :img_seq_len, ...]
# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.unpatchify(img, tt, th, tw)
if return_dict:
out["x"] = img
return out
return img
def unpatchify(self, x, t, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
counts = {
"double": sum(
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum(
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
def set_input_tensor(self, input_tensor):
pass
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
"HYVideo-S/2": {
"mm_double_blocks_depth": 6,
"mm_single_blocks_depth": 12,
"rope_dim_list": [12, 42, 42],
"hidden_size": 480,
"heads_num": 5,
"mlp_width_ratio": 4,
},
}
from typing import Callable
import torch
import torch.nn as nn
import math
class ModulateDiT(nn.Module):
"""Modulation layer for DiT."""
def __init__(
self,
hidden_size: int,
factor: int,
act_layer: Callable,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.act = act_layer()
self.linear = nn.Linear(
hidden_size, factor * hidden_size, bias=True, **factory_kwargs
)
# Zero-initialize the modulation
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x: torch.Tensor, condition_type=None, token_replace_vec=None) -> torch.Tensor:
x_out = self.linear(self.act(x))
if condition_type == "token_replace":
x_token_replace_out = self.linear(self.act(token_replace_vec))
return x_out, x_token_replace_out
else:
return x_out
def modulate(x, shift=None, scale=None, condition_type=None,
tr_shift=None, tr_scale=None,
frist_frame_token_num=None):
"""modulate by shift and scale
Args:
x (torch.Tensor): input tensor.
shift (torch.Tensor, optional): shift tensor. Defaults to None.
scale (torch.Tensor, optional): scale tensor. Defaults to None.
Returns:
torch.Tensor: the output tensor after modulate.
"""
if condition_type == "token_replace":
x_zero = x[:, :frist_frame_token_num] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
x_orig = x[:, frist_frame_token_num:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x = torch.concat((x_zero, x_orig), dim=1)
return x
else:
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, frist_frame_token_num=None):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after apply gate.
"""
if condition_type == "token_replace":
if gate is None:
return x
if tanh:
x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1).tanh()
x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1).tanh()
x = torch.concat((x_zero, x_orig), dim=1)
return x
else:
x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1)
x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1)
x = torch.concat((x_zero, x_orig), dim=1)
return x
else:
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
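# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# Mirrors how the transformer blocks combine these helpers: ModulateDiT maps a
# conditioning vector to (shift, scale, gate), modulate applies the affine part
# before a sub-layer, and apply_gate scales the residual update. Since
# ModulateDiT is zero-initialized, the gate starts at zero and the residual
# update is a no-op for a freshly built module. Sizes are hypothetical.
def _demo_modulation_flow():
    hidden = 64
    mod = ModulateDiT(hidden, factor=3, act_layer=nn.SiLU)
    vec = torch.randn(2, hidden)               # conditioning vector
    shift, scale, gate = mod(vec).chunk(3, dim=-1)
    x = torch.randn(2, 10, hidden)             # [B, tokens, hidden]
    x_mod = modulate(x, shift=shift, scale=scale)
    assert x_mod.shape == x.shape
    x_out = x + apply_gate(x_mod, gate=gate)   # gated residual update
    assert torch.allclose(x_out, x)            # zero-init gate => identity
    return x_out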
def ckpt_wrapper(module):
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs
return ckpt_forward
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
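# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# RMSNorm rescales each token so its root-mean-square over the last dimension
# is approximately one, then applies the learnable per-channel weight (all ones
# at initialization). The helper below is hypothetical.
def _demo_rms_norm():
    norm = RMSNorm(dim=16)
    x = torch.randn(2, 5, 16) * 3.0
    y = norm(x)
    rms = y.pow(2).mean(dim=-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
    return y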
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
import torch
from typing import Union, Tuple, List
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
        grid (torch.Tensor): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
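# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# For a latent video of 4 x 8 x 8 patches, the meshgrid enumerates (t, h, w)
# coordinates that the RoPE functions below flatten into per-axis position ids.
# Sizes are hypothetical.
def _demo_meshgrid():
    grid = get_meshgrid_nd((4, 8, 8), dim=3)
    assert grid.shape == (3, 4, 8, 8)
    assert grid[0, 1, 0, 0] == 1.0             # the first channel indexes time
    return grid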
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
def reshape_for_broadcast(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
x: torch.Tensor,
head_first=False,
):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Notes:
When using FlashMHAModified, head_first should be False.
When using Attention, head_first should be True.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = (
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], -1, 2)
) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
xq.device
) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], -1, 2)
) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
        pos_embed (torch.Tensor): a [T*H*W, D/2] complex tensor when use_real is False; otherwise a (cos, sin)
            pair of [T*H*W, D] real tensors, where D = sum(rope_dim_list).
"""
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(
rope_dim_list
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(
rope_dim_list
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (T*H*W, sum(rope_dim_list))
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (T*H*W, sum(rope_dim_list))
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis
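# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# Builds the 3-axis RoPE tables in real (cos, sin) form for a 4 x 8 x 8 patch
# grid, with the per-axis split used by the T/2 config, then rotates query/key
# tensors whose head_dim equals sum(rope_dim_list). Rotation changes the values
# but not the shapes. Sizes and the helper name are hypothetical.
def _demo_nd_rope():
    rope_dim_list = [16, 56, 56]               # t, h, w split of the head dim
    cos, sin = get_nd_rotary_pos_embed(rope_dim_list, (4, 8, 8), use_real=True)
    seq_len, head_dim = 4 * 8 * 8, sum(rope_dim_list)
    assert cos.shape == (seq_len, head_dim) and sin.shape == (seq_len, head_dim)
    xq = torch.randn(1, seq_len, 8, head_dim)  # [B, S, H, D], head_first=False
    xk = torch.randn(1, seq_len, 8, head_dim)
    q_rot, k_rot = apply_rotary_emb(xq, xk, (cos, sin), head_first=False)
    assert q_rot.shape == xq.shape and k_rot.shape == xk.shape
    return q_rot, k_rot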
from typing import Optional
from einops import rearrange
import torch
import torch.nn as nn
from .activation_layers import get_activation_layer
from .attenion import attention
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, TextProjection
from .attenion import attention
from .mlp_layers import MLP
from .modulate_layers import modulate, apply_gate
class IndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
hidden_size,
heads_num,
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.self_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.self_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
self.mlp = MLP(
in_channels=hidden_size,
hidden_channels=mlp_hidden_dim,
act_layer=act_layer,
drop=mlp_drop_rate,
**factory_kwargs,
)
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
# FFN Layer
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
return x
class IndividualTokenRefiner(nn.Module):
def __init__(
self,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.blocks = nn.ModuleList(
[
IndividualTokenRefinerBlock(
hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(depth)
]
)
def forward(
self,
x: torch.Tensor,
c: torch.LongTensor,
mask: Optional[torch.Tensor] = None,
):
self_attn_mask = None
if mask is not None:
batch_size = mask.shape[0]
seq_len = mask.shape[1]
mask = mask.to(x.device)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
1, 1, seq_len, 1
)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
# avoids self-attention weight being NaN for padding tokens
self_attn_mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, self_attn_mask)
return x
class SingleTokenRefiner(nn.Module):
"""
    A token refiner that refines LLM text embeddings before they are fed to the transformer.
"""
def __init__(
self,
in_channels,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
attn_mode: str = "torch",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.attn_mode = attn_mode
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
self.input_embedder = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
# Build timestep embedding layer
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
# Build context embedding layer
self.c_embedder = TextProjection(
in_channels, hidden_size, act_layer, **factory_kwargs
)
self.individual_token_refiner = IndividualTokenRefiner(
hidden_size=hidden_size,
heads_num=heads_num,
depth=depth,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
def forward(
self,
x: torch.Tensor,
t: torch.LongTensor,
mask: Optional[torch.LongTensor] = None,
):
timestep_aware_representations = self.t_embedder(t)
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
context_aware_representations = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
return x
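# --- Illustrative usage (editor's sketch, not part of the upstream module) ---
# A small end-to-end pass through the token refiner: LLM-style text embeddings
# of width in_channels plus a timestep and padding mask come out as refined
# tokens of width hidden_size, using the "torch" (SDPA) attention mode enforced
# above. The tiny sizes are hypothetical and the demo is never called at import.
def _demo_token_refiner():
    refiner = SingleTokenRefiner(in_channels=32, hidden_size=64, heads_num=4,
                                 depth=1)
    x = torch.randn(2, 10, 32)                  # [B, L, C_text]
    t = torch.tensor([500.0, 500.0])            # timesteps
    mask = torch.ones(2, 10, dtype=torch.long)  # 1 = real token, 0 = padding
    out = refiner(x, t, mask)
    assert out.shape == (2, 10, 64)
    return out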
from dataclasses import dataclass
from typing import Optional, Tuple
from copy import deepcopy
import torch
import torch.nn as nn
from transformers import (
CLIPTextModel,
CLIPTokenizer,
AutoTokenizer,
AutoModel,
LlavaForConditionalGeneration,
CLIPImageProcessor,
)
from transformers.utils import ModelOutput
from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
from ..constants import PRECISION_TO_TYPE
def use_default(value, default):
return value if value is not None else default
def load_text_encoder(
text_encoder_type,
text_encoder_precision=None,
text_encoder_path=None,
logger=None,
device=None,
):
if text_encoder_path is None:
text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
if logger is not None:
logger.info(
f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}"
)
if text_encoder_type == "clipL":
text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
elif text_encoder_type == "llm":
text_encoder = AutoModel.from_pretrained(
text_encoder_path, low_cpu_mem_usage=True
)
text_encoder.final_layer_norm = text_encoder.norm
elif text_encoder_type == "llm-i2v":
text_encoder = LlavaForConditionalGeneration.from_pretrained(
text_encoder_path, low_cpu_mem_usage=True
)
else:
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
# from_pretrained will ensure that the model is in eval mode.
if text_encoder_precision is not None:
text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
text_encoder.requires_grad_(False)
if logger is not None:
logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
if device is not None:
text_encoder = text_encoder.to(device)
return text_encoder, text_encoder_path
def load_tokenizer(
tokenizer_type, tokenizer_path=None, padding_side="right", logger=None
):
if tokenizer_path is None:
tokenizer_path = TOKENIZER_PATH[tokenizer_type]
if logger is not None:
logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
processor = None
if tokenizer_type == "clipL":
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
elif tokenizer_type == "llm":
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, padding_side=padding_side
)
elif tokenizer_type == "llm-i2v":
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, padding_side=padding_side
)
processor = CLIPImageProcessor.from_pretrained(tokenizer_path)
else:
raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
return tokenizer, tokenizer_path, processor
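# Hypothetical usage sketch (not part of the original file): for "llm-i2v" the loader
# also returns a CLIPImageProcessor so the i2v reference image can be preprocessed.
def _example_load_tokenizer():
    tokenizer, resolved_path, processor = load_tokenizer(
        tokenizer_type="llm-i2v", padding_side="right"
    )
    return tokenizer, resolved_path, processor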
@dataclass
class TextEncoderModelOutput(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 1 for tokens that are **not masked**, 0 for masked tokens.
hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
List of decoded texts.
"""
hidden_state: torch.FloatTensor = None
attention_mask: Optional[torch.LongTensor] = None
hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
text_outputs: Optional[list] = None
class TextEncoder(nn.Module):
def __init__(
self,
text_encoder_type: str,
max_length: int,
text_encoder_precision: Optional[str] = None,
text_encoder_path: Optional[str] = None,
tokenizer_type: Optional[str] = None,
tokenizer_path: Optional[str] = None,
output_key: Optional[str] = None,
use_attention_mask: bool = True,
i2v_mode: bool = False,
input_max_length: Optional[int] = None,
prompt_template: Optional[dict] = None,
prompt_template_video: Optional[dict] = None,
hidden_state_skip_layer: Optional[int] = None,
apply_final_norm: bool = False,
reproduce: bool = False,
logger=None,
device=None,
image_embed_interleave=None,
):
super().__init__()
self.text_encoder_type = text_encoder_type
self.max_length = max_length
self.precision = text_encoder_precision
self.model_path = text_encoder_path
self.tokenizer_type = (
tokenizer_type if tokenizer_type is not None else text_encoder_type
)
self.tokenizer_path = (
tokenizer_path if tokenizer_path is not None else text_encoder_path
)
self.use_attention_mask = use_attention_mask
if prompt_template_video is not None:
assert (
use_attention_mask is True
), "Attention mask is True required when training videos."
self.input_max_length = (
input_max_length if input_max_length is not None else max_length
)
self.prompt_template = prompt_template
self.prompt_template_video = prompt_template_video
self.hidden_state_skip_layer = hidden_state_skip_layer
self.apply_final_norm = apply_final_norm
self.i2v_mode = i2v_mode
self.reproduce = reproduce
self.logger = logger
self.image_embed_interleave = image_embed_interleave
self.use_template = self.prompt_template is not None
if self.use_template:
assert (
isinstance(self.prompt_template, dict)
and "template" in self.prompt_template
), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
assert "{}" in str(self.prompt_template["template"]), (
"`prompt_template['template']` must contain a placeholder `{}` for the input text, "
f"got {self.prompt_template['template']}"
)
self.use_video_template = self.prompt_template_video is not None
if self.use_video_template:
if self.prompt_template_video is not None:
assert (
isinstance(self.prompt_template_video, dict)
and "template" in self.prompt_template_video
), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
assert "{}" in str(self.prompt_template_video["template"]), (
"`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
f"got {self.prompt_template_video['template']}"
)
if "t5" in text_encoder_type:
self.output_key = output_key or "last_hidden_state"
elif "clip" in text_encoder_type:
self.output_key = output_key or "pooler_output"
elif "llm" in text_encoder_type or "glm" in text_encoder_type:
self.output_key = output_key or "last_hidden_state"
else:
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
self.model, self.model_path = load_text_encoder(
text_encoder_type=self.text_encoder_type,
text_encoder_precision=self.precision,
text_encoder_path=self.model_path,
logger=self.logger,
device=device,
)
self.dtype = self.model.dtype
self.device = self.model.device
self.tokenizer, self.tokenizer_path, self.processor = load_tokenizer(
tokenizer_type=self.tokenizer_type,
tokenizer_path=self.tokenizer_path,
padding_side="right",
logger=self.logger,
)
def __repr__(self):
return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
@staticmethod
def apply_text_to_template(text, template, prevent_empty_text=True):
"""
Apply text to template.
Args:
text (str): Input text.
template (str or list): Template string or list of chat conversation.
            prevent_empty_text (bool): If True, prevent the user text from being empty
by adding a space. Defaults to True.
"""
if isinstance(template, str):
# Will send string to tokenizer. Used for llm
return template.format(text)
else:
raise TypeError(f"Unsupported template type: {type(template)}")
def text2tokens(self, text, data_type="image"):
"""
Tokenize the input text.
Args:
text (str or list): Input text.
"""
tokenize_input_type = "str"
if self.use_template:
if data_type == "image":
prompt_template = self.prompt_template["template"]
elif data_type == "video":
prompt_template = self.prompt_template_video["template"]
else:
raise ValueError(f"Unsupported data type: {data_type}")
if isinstance(text, (list, tuple)):
text = [
self.apply_text_to_template(one_text, prompt_template)
for one_text in text
]
if isinstance(text[0], list):
tokenize_input_type = "list"
elif isinstance(text, str):
text = self.apply_text_to_template(text, prompt_template)
if isinstance(text, list):
tokenize_input_type = "list"
else:
raise TypeError(f"Unsupported text type: {type(text)}")
kwargs = dict(
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
if tokenize_input_type == "str":
return self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
**kwargs,
)
elif tokenize_input_type == "list":
return self.tokenizer.apply_chat_template(
text,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
**kwargs,
)
else:
raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
def encode(
self,
batch_encoding,
use_attention_mask=None,
output_hidden_states=False,
do_sample=None,
hidden_state_skip_layer=None,
return_texts=False,
data_type="image",
semantic_images=None,
device=None,
):
"""
Args:
batch_encoding (dict): Batch encoding from tokenizer.
use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
Defaults to None.
output_hidden_states (bool): Whether to output hidden states. If False, return the value of
self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
output_hidden_states will be set True. Defaults to False.
            do_sample (bool): Whether to sample from the model. Used for decoder-only LLMs. Defaults to None.
                When self.reproduce is False, do_sample defaults to True.
            hidden_state_skip_layer (int): Number of final layers to skip when selecting the hidden state;
                0 means the last layer. If None, self.hidden_state_skip_layer is used. Defaults to None.
            semantic_images (PIL.Image): The reference image(s) for i2v models.
return_texts (bool): Whether to return the decoded texts. Defaults to False.
"""
device = self.model.device if device is None else device
use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
hidden_state_skip_layer = use_default(
hidden_state_skip_layer, self.hidden_state_skip_layer
)
do_sample = use_default(do_sample, not self.reproduce)
if not self.i2v_mode:
attention_mask = (
batch_encoding["attention_mask"].to(device)
if use_attention_mask
else None
)
outputs = self.model(
input_ids=batch_encoding["input_ids"].to(device),
attention_mask=attention_mask,
output_hidden_states=output_hidden_states
or hidden_state_skip_layer is not None,
)
if hidden_state_skip_layer is not None:
last_hidden_state = outputs.hidden_states[
-(hidden_state_skip_layer + 1)
]
# Real last hidden state already has layer norm applied. So here we only apply it
# for intermediate layers.
if hidden_state_skip_layer > 0 and self.apply_final_norm:
last_hidden_state = self.model.final_layer_norm(last_hidden_state)
else:
last_hidden_state = outputs[self.output_key]
# Remove hidden states of instruction tokens, only keep prompt tokens.
if self.use_template:
if data_type == "image":
crop_start = self.prompt_template.get("crop_start", -1)
elif data_type == "video":
crop_start = self.prompt_template_video.get("crop_start", -1)
else:
raise ValueError(f"Unsupported data type: {data_type}")
if crop_start > 0:
last_hidden_state = last_hidden_state[:, crop_start:]
attention_mask = (
attention_mask[:, crop_start:] if use_attention_mask else None
)
if output_hidden_states:
return TextEncoderModelOutput(
last_hidden_state, attention_mask, outputs.hidden_states
)
return TextEncoderModelOutput(last_hidden_state, attention_mask)
else:
image_outputs = self.processor(semantic_images, return_tensors="pt")[
"pixel_values"
].to(device)
attention_mask = (
batch_encoding["attention_mask"].to(device)
if use_attention_mask
else None
)
outputs = self.model(
input_ids=batch_encoding["input_ids"].to(device),
attention_mask=attention_mask,
output_hidden_states=output_hidden_states
or hidden_state_skip_layer is not None,
pixel_values=image_outputs,
)
if hidden_state_skip_layer is not None:
last_hidden_state = outputs.hidden_states[
-(hidden_state_skip_layer + 1)
]
# Real last hidden state already has layer norm applied. So here we only apply it
# for intermediate layers.
if hidden_state_skip_layer > 0 and self.apply_final_norm:
last_hidden_state = self.model.final_layer_norm(last_hidden_state)
else:
last_hidden_state = outputs[self.output_key]
if self.use_template:
if data_type == "video":
crop_start = self.prompt_template_video.get("crop_start", -1)
text_crop_start = (
crop_start
- 1
+ self.prompt_template_video.get("image_emb_len", 576)
)
image_crop_start = self.prompt_template_video.get(
"image_emb_start", 5
)
image_crop_end = self.prompt_template_video.get(
"image_emb_end", 581
)
batch_indices, last_double_return_token_indices = torch.where(
batch_encoding["input_ids"]
== self.prompt_template_video.get("double_return_token_id", 271)
)
if last_double_return_token_indices.shape[0] == 3:
# in case the prompt is too long
last_double_return_token_indices = torch.cat(
(
last_double_return_token_indices,
torch.tensor([batch_encoding["input_ids"].shape[-1]]).to(
device=last_double_return_token_indices.device),
)
)
last_double_return_token_indices = (
last_double_return_token_indices.reshape(
batch_encoding["input_ids"].shape[0], -1
)[:, -1]
)
assistant_crop_start = (
last_double_return_token_indices
- 1
+ self.prompt_template_video.get("image_emb_len", 576)
- 4
)
assistant_crop_end = (
last_double_return_token_indices
- 1
+ self.prompt_template_video.get("image_emb_len", 576)
)
attention_mask_assistant_crop_start = (
last_double_return_token_indices - 4
)
attention_mask_assistant_crop_end = last_double_return_token_indices
else:
raise ValueError(f"Unsupported data type: {data_type}")
text_last_hidden_state = []
text_attention_mask = []
image_last_hidden_state = []
image_attention_mask = []
for i in range(batch_encoding["input_ids"].shape[0]):
text_last_hidden_state.append(
torch.cat(
[
last_hidden_state[
i, text_crop_start : assistant_crop_start[i].item()
],
last_hidden_state[i, assistant_crop_end[i].item() :],
]
)
)
text_attention_mask.append(
torch.cat(
[
attention_mask[
i,
crop_start : attention_mask_assistant_crop_start[
i
].item(),
],
attention_mask[
i, attention_mask_assistant_crop_end[i].item() :
],
]
)
if use_attention_mask
else None
)
image_last_hidden_state.append(
last_hidden_state[i, image_crop_start:image_crop_end]
)
image_attention_mask.append(
torch.ones(image_last_hidden_state[-1].shape[0])
.to(last_hidden_state.device)
.to(attention_mask.dtype)
if use_attention_mask
else None
)
text_last_hidden_state = torch.stack(text_last_hidden_state)
text_attention_mask = torch.stack(text_attention_mask)
image_last_hidden_state = torch.stack(image_last_hidden_state)
image_attention_mask = torch.stack(image_attention_mask)
if semantic_images is not None and 0 < self.image_embed_interleave < 6:
image_last_hidden_state = image_last_hidden_state[
:, ::self.image_embed_interleave, :
]
image_attention_mask = image_attention_mask[
:, ::self.image_embed_interleave
]
assert (
text_last_hidden_state.shape[0] == text_attention_mask.shape[0]
and image_last_hidden_state.shape[0]
== image_attention_mask.shape[0]
)
last_hidden_state = torch.cat(
[image_last_hidden_state, text_last_hidden_state], dim=1
)
attention_mask = torch.cat(
[image_attention_mask, text_attention_mask], dim=1
)
if output_hidden_states:
return TextEncoderModelOutput(
last_hidden_state,
attention_mask,
hidden_states_list=outputs.hidden_states,
)
return TextEncoderModelOutput(last_hidden_state, attention_mask)
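    # Note (not in the original file): in i2v mode the returned hidden_state is the image
    # token embeddings (optionally subsampled via image_embed_interleave) concatenated in
    # front of the cropped prompt token embeddings, with attention_mask assembled to match.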
def forward(
self,
text,
use_attention_mask=None,
output_hidden_states=False,
do_sample=False,
hidden_state_skip_layer=None,
return_texts=False,
):
batch_encoding = self.text2tokens(text)
return self.encode(
batch_encoding,
use_attention_mask=use_attention_mask,
output_hidden_states=output_hidden_states,
do_sample=do_sample,
hidden_state_skip_layer=hidden_state_skip_layer,
return_texts=return_texts,
)
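# Hypothetical end-to-end sketch (not part of the original file): build a CLIP text
# encoder and encode one prompt. The checkpoint path and the "fp16" precision key are
# assumptions; the "llm-i2v" encoder additionally needs prompt templates and a reference
# image passed to encode(..., semantic_images=...).
def _example_text_encoder():
    text_encoder = TextEncoder(
        text_encoder_type="clipL",
        max_length=77,
        text_encoder_precision="fp16",
        text_encoder_path="./ckpts/text_encoder_2",  # hypothetical local path
        output_key="pooler_output",
        device="cuda:0",
    )
    output = text_encoder("A cat runs across a sunlit garden.")
    return output.hidden_state, output.attention_mask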