Commit db4238af authored by PengGao, committed by GitHub

Feat vfi (#158)



* feat: add vfi(rife)

* refactor: update video saving functionality and add VAE to ComfyUI image conversion.

* refactor: enhance video saving process and integrate VAE to ComfyUI image conversion

* refactor: reorganize imports and enhance code structure in default_runner and utils modules

* feat: add video frame interpolation support with RIFE model integration

* refactor: streamline code structure and improve readability in IFNet and refine modules

* fix: style

* feat: add script for downloading and installing RIFE model with flownet.pkl extraction

* update

* docs: update README and documentation to include video frame interpolation feature with RIFE model

* fix ci

* docs: enhance video frame interpolation documentation and update configuration usage for RIFE model

---------
Co-authored-by: helloyongyang <yongyang1030@163.com>
parent b48d08ae
......@@ -27,3 +27,4 @@ dist/
.cache/
server_cache/
app/.gradio/
*.pkl
......@@ -63,6 +63,7 @@ For comprehensive usage instructions, please refer to our documentation: **[Engl
- **🔄 Parallel Inference**: Multi-GPU parallel processing for enhanced performance
- **📱 Flexible Deployment Options**: Support for Gradio, service deployment, ComfyUI and other deployment methods
- **🎛️ Dynamic Resolution Inference**: Adaptive resolution adjustment for optimal generation quality
- **🎞️ Video Frame Interpolation**: RIFE-based frame interpolation for smooth frame rate enhancement
## 🏆 Performance Benchmarks
......@@ -80,6 +81,7 @@ For detailed performance metrics and comparisons, please refer to our [benchmark
- [Parameter Offloading](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/offload.html) - Three-tier storage architecture
- [Parallel Inference](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/parallel.html) - Multi-GPU acceleration strategies
- [Step Distillation](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/step_distill.html) - 4-step inference technology
- [Video Frame Interpolation](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/video_frame_interpolation.html) - RIFE-based frame interpolation
### 🛠️ **Deployment Guides**
- [Low-Resource Deployment](https://lightx2v-en.readthedocs.io/en/latest/deploy_guides/for_low_resource.html) - Optimized 8GB VRAM solutions
......
......@@ -61,6 +61,7 @@
- **🔄 并行推理加速**: 多GPU并行处理,显著提升性能表现
- **📱 灵活部署选择**: 支持Gradio、服务化部署、ComfyUI等多种部署方式
- **🎛️ 动态分辨率推理**: 自适应分辨率调整,优化生成质量
- **🎞️ 视频帧插值**: 基于RIFE的帧插值技术,实现流畅的帧率提升
## 🏆 性能基准测试
......@@ -78,6 +79,7 @@
- [参数卸载](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/offload.html) - 三级存储架构
- [并行推理](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/parallel.html) - 多GPU加速策略
- [步数蒸馏](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/step_distill.html) - 4步推理技术
- [视频帧插值](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/video_frame_interpolation.html) - 基于RIFE的帧插值技术
### 🛠️ **部署指南**
- [低资源场景部署](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/for_low_resource.html) - 优化的8GB显存解决方案
......
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"video_frame_interpolation": {
"algo": "rife",
"target_fps": 24,
"model_path": "/path to flownet.pkl"
}
}
......@@ -51,6 +51,7 @@ Documentation
Changing Resolution Inference <method_tutorials/changing_resolution.md>
Step Distill <method_tutorials/step_distill.md>
Autoregressive Distill <method_tutorials/autoregressive_distill.md>
Video Frame Interpolation <method_tutorials/video_frame_interpolation.md>
.. toctree::
:maxdepth: 1
......
# Video Frame Interpolation (VFI)
> **Important Note**: Video frame interpolation is enabled through configuration files, not command-line parameters. Please add a `video_frame_interpolation` configuration block to your JSON config file to enable this feature.
## Overview
Video Frame Interpolation (VFI) is a technique that generates intermediate frames between existing frames to increase the frame rate and create smoother video playback. LightX2V integrates the RIFE (Real-Time Intermediate Flow Estimation) model to provide high-quality frame interpolation capabilities.
## What is RIFE?
RIFE is a state-of-the-art video frame interpolation method that uses optical flow estimation to generate intermediate frames. It can effectively:
- Increase video frame rate (e.g., from 16 FPS to 32 FPS)
- Create smooth motion transitions
- Maintain high visual quality with minimal artifacts
- Process videos in real-time
## Installation and Setup
### Download RIFE Model
First, download the RIFE model weights using the provided script:
```bash
python tools/download_rife.py <target_directory>
```
For example, to download the model weights to a specific directory:
```bash
python tools/download_rife.py /path/to/rife/train_log
```
This script will:
- Download RIFEv4.26 model from HuggingFace
- Extract and place the model files in the correct directory
- Clean up temporary files
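After the script finishes, the weights should be available as `flownet.pkl` inside the target directory. A minimal sanity check, assuming the example target path used above:
```python
# quick check that the RIFE weights were installed; the path is the example target from above
from pathlib import Path

weights = Path("/path/to/rife/train_log") / "flownet.pkl"
assert weights.exists(), "flownet.pkl not found - run tools/download_rife.py first"
print(f"RIFE weights found at {weights} ({weights.stat().st_size / 1e6:.1f} MB)")
```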
## Usage
### Configuration File Setup
Video frame interpolation is enabled through configuration files. Add a `video_frame_interpolation` configuration block to your JSON config file:
```json
{
"infer_steps": 50,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"fps": 16,
"video_frame_interpolation": {
"algo": "rife",
"target_fps": 32,
"model_path": "/path/to/rife/train_log"
}
}
```
### Command Line Interface
Run inference using a configuration file that includes VFI settings:
```bash
python lightx2v/infer.py \
--model_cls wan2.1 \
--task t2v \
--model_path /path/to/model \
--config_json ./configs/video_frame_interpolation/wan_t2v.json \
--prompt "A beautiful sunset over the ocean" \
--save_video_path ./output.mp4
```
### Configuration Parameters
In the `video_frame_interpolation` configuration block:
- `algo`: Frame interpolation algorithm, currently supports "rife"
- `target_fps`: Target frame rate for the output video
- `model_path`: RIFE model path, typically "/path/to/rife/train_log"
Other related configurations:
- `fps`: Source video frame rate (default 16)
### Configuration Priority
The system automatically handles video frame rate configuration with the following priority:
1. `video_frame_interpolation.target_fps` - If video frame interpolation is enabled, this frame rate is used as the output frame rate
2. `fps` (default 16) - Used as the output frame rate when video frame interpolation is not enabled; it is always used as the source frame rate
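For reference, the sketch below paraphrases how the runner resolves these values in this PR; `config` stands for the parsed JSON config file and the names are simplified:
```python
# simplified paraphrase of the frame-rate resolution logic added to DefaultRunner in this PR
vfi_cfg = config.get("video_frame_interpolation", {})
source_fps = config.get("fps", 16)        # the base model always generates at this rate
if vfi_cfg.get("target_fps"):
    output_fps = vfi_cfg["target_fps"]    # VFI enabled: interpolate and save at target_fps
else:
    output_fps = source_fps               # VFI disabled: save at the source rate
```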
## How It Works
### Frame Interpolation Process
1. **Source Video Generation**: The base model generates video frames at the source FPS
2. **Frame Analysis**: RIFE analyzes adjacent frames to estimate optical flow
3. **Intermediate Frame Generation**: New frames are generated between existing frames
4. **Temporal Smoothing**: The interpolated frames create smooth motion transitions
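The mapping from output frames back to source frames is time-based. The sketch below paraphrases `_calculate_target_frame_positions` from the `RIFEWrapper` added in this PR; an interpolation factor of 0 means the source frame is reused directly:
```python
def target_frame_positions(source_fps: float, target_fps: float, total_source_frames: int):
    """For each output frame, return (left source index, right source index, blend factor)."""
    duration = (total_source_frames - 1) / source_fps
    total_target_frames = int(duration * target_fps) + 1
    positions = []
    for target_idx in range(total_target_frames):
        source_position = target_idx / target_fps * source_fps  # position in source-frame units
        idx1 = int(source_position)
        idx2 = min(idx1 + 1, total_source_frames - 1)
        factor = 0.0 if idx1 == idx2 else source_position - idx1
        positions.append((idx1, idx2, factor))
    return positions
```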
### Technical Details
- **Input Format**: ComfyUI Image tensors [N, H, W, C] in range [0, 1]
- **Output Format**: Interpolated ComfyUI Image tensors [M, H, W, C] in range [0, 1]
- **Processing**: Automatic padding and resolution handling
- **Memory Optimization**: Efficient GPU memory management
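RIFE operates on spatial sizes that are multiples of 128 (adjusted by the processing scale), so frames are padded before inference and cropped back afterwards. A sketch of the padding rule used by the wrapper, shown here for the default 480x832 resolution:
```python
# padding rule used by RIFEWrapper before calling the flow network (scale=1.0 by default)
height, width, scale = 480, 832, 1.0
tmp = max(128, int(128 / scale))
ph = ((height - 1) // tmp + 1) * tmp        # 512
pw = ((width - 1) // tmp + 1) * tmp         # 896
padding = (0, pw - width, 0, ph - height)   # F.pad order on [1, C, H, W]: (W_left, W_right, H_top, H_bottom)
```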
## Example Configurations
### Basic Frame Rate Doubling
Create configuration file `wan_t2v_vfi_32fps.json`:
```json
{
"infer_steps": 50,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"seed": 42,
"sample_guide_scale": 6,
"enable_cfg": true,
"fps": 16,
"video_frame_interpolation": {
"algo": "rife",
"target_fps": 32,
"model_path": "/path/to/rife/train_log"
}
}
```
Run command:
```bash
python lightx2v/infer.py \
--model_cls wan2.1 \
--task t2v \
--model_path ./models/wan2.1 \
--config_json ./wan_t2v_vfi_32fps.json \
--prompt "A cat playing in the garden" \
--save_video_path ./output_32fps.mp4
```
### Higher Frame Rate Enhancement
Create configuration file `wan_i2v_vfi_60fps.json`:
```json
{
"infer_steps": 30,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"seed": 42,
"sample_guide_scale": 6,
"enable_cfg": true,
"fps": 16,
"video_frame_interpolation": {
"algo": "rife",
"target_fps": 60,
"model_path": "/path/to/rife/train_log"
}
}
```
Run command:
```bash
python lightx2v/infer.py \
--model_cls wan2.1 \
--task i2v \
--model_path ./models/wan2.1 \
--config_json ./wan_i2v_vfi_60fps.json \
--image_path ./input.jpg \
--prompt "Smooth camera movement" \
--save_video_path ./output_60fps.mp4
```
## Performance Considerations
### Memory Usage
- RIFE processing requires additional GPU memory
- Memory usage scales with video resolution and length
- Consider using lower resolutions for longer videos
### Processing Time
- Frame interpolation adds processing overhead
- Higher target frame rates require more computation
- Processing time is roughly proportional to the number of interpolated frames
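As a rough back-of-the-envelope count using the time-based mapping above, doubling the default 81-frame output from 16 FPS to 32 FPS adds about 80 RIFE inference calls:
```python
# rough count of extra work when going from 16 FPS to 32 FPS with 81 source frames
total_source_frames = 81
duration = (total_source_frames - 1) / 16                       # 5.0 seconds
total_target_frames = int(duration * 32) + 1                    # 161 output frames
extra_rife_calls = total_target_frames - total_source_frames    # 80 interpolated frames
```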
### Quality vs Speed Trade-offs
- Higher interpolation ratios may introduce artifacts
- Optimal range: 2x to 4x frame rate increase
- For extreme interpolation (>4x), consider multiple passes
## Best Practices
### Optimal Use Cases
- **Motion-heavy videos**: Benefit most from frame interpolation
- **Camera movements**: Smoother panning and zooming
- **Action sequences**: Reduced motion blur perception
- **Slow-motion effects**: Create fluid slow-motion videos
### Recommended Settings
- **Source FPS**: 16-24 FPS (generated by the base model)
- **Target FPS**: 32-60 FPS (2x to 4x increase)
- **Resolution**: Up to 720p for best performance
### Troubleshooting
#### Common Issues
1. **Out of Memory**: Reduce video resolution or target FPS
2. **Artifacts in output**: Lower the interpolation ratio
3. **Slow processing**: Check GPU memory and consider using CPU offloading
#### Solutions
Solve issues by modifying the configuration file:
```json
{
// For memory issues, use lower resolution
"target_height": 480,
"target_width": 832,
// For quality issues, use moderate interpolation
"video_frame_interpolation": {
"target_fps": 24 // instead of 60
},
// For performance issues, enable offloading
"cpu_offload": true
}
```
## Technical Implementation
The RIFE integration in LightX2V includes:
- **RIFEWrapper**: ComfyUI-compatible wrapper for RIFE model
- **Automatic Model Loading**: Seamless integration with the inference pipeline
- **Memory Optimization**: Efficient tensor management and GPU memory usage
- **Quality Preservation**: Maintains original video quality while adding frames
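For a quick picture of how these pieces fit together, here is a minimal usage sketch of the wrapper; the model path is illustrative and must point to the RIFE weights downloaded earlier, and the frame count and resolution follow the examples above:
```python
import torch
from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

wrapper = RIFEWrapper("/path/to/rife/train_log")      # loads the flownet weights
frames = torch.rand(81, 480, 832, 3)                  # ComfyUI-style [N, H, W, C] in [0, 1]
smooth = wrapper.interpolate_frames(frames, source_fps=16, target_fps=32)
print(smooth.shape)                                   # torch.Size([161, 480, 832, 3])
```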
......@@ -52,6 +52,7 @@ HuggingFace: https://huggingface.co/lightx2v
变分辨率推理 <method_tutorials/changing_resolution.md>
步数蒸馏 <method_tutorials/step_distill.md>
自回归蒸馏 <method_tutorials/autoregressive_distill.md>
视频帧插值 <method_tutorials/video_frame_interpolation.md>
.. toctree::
:maxdepth: 1
......
# 视频帧插值 (VFI)
> **重要说明**: 视频帧插值功能通过配置文件启用,而不是通过命令行参数。请在配置 JSON 文件中添加 `video_frame_interpolation` 配置块来启用此功能。
## 概述
视频帧插值(VFI)是一种在现有帧之间生成中间帧的技术,用于提高帧率并创建更流畅的视频播放效果。LightX2V 集成了 RIFE(Real-Time Intermediate Flow Estimation)模型,提供高质量的帧插值能力。
## 什么是 RIFE?
RIFE 是一种最先进的视频帧插值方法,使用光流估计来生成中间帧。它能够有效地:
- 提高视频帧率(例如,从 16 FPS 提升到 32 FPS)
- 创建平滑的运动过渡
- 保持高视觉质量,最少伪影
- 实时处理视频
## 安装和设置
### 下载 RIFE 模型
首先,使用提供的脚本下载 RIFE 模型权重:
```bash
python tools/download_rife.py <目标目录>
```
例如,下载到指定位置:
```bash
python tools/download_rife.py /path/to/rife/train_log
```
此脚本将:
- 从 HuggingFace 下载 RIFEv4.26 模型
- 提取并将模型文件放置在正确的目录中
- 清理临时文件
## 使用方法
### 配置文件设置
视频帧插值功能通过配置文件启用。在你的配置 JSON 文件中添加 `video_frame_interpolation` 配置块:
```json
{
"infer_steps": 50,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"fps": 16,
"video_frame_interpolation": {
"algo": "rife",
"target_fps": 32,
"model_path": "/path/to/rife/train_log"
}
}
```
### 命令行使用
使用包含 VFI 配置的配置文件运行推理:
```bash
python lightx2v/infer.py \
--model_cls wan2.1 \
--task t2v \
--model_path /path/to/model \
--config_json ./configs/video_frame_interpolation/wan_t2v.json \
--prompt "美丽的海上日落" \
--save_video_path ./output.mp4
```
### 配置参数说明
`video_frame_interpolation` 配置块中:
- `algo`: 帧插值算法,目前支持 "rife"
- `target_fps`: 输出视频的目标帧率
- `model_path`: RIFE 模型路径,通常为 "/path/to/rife/train_log"
其他相关配置:
- `fps`: 源视频帧率(默认 16)
### 配置优先级
系统会自动处理视频帧率配置,优先级如下:
1. `video_frame_interpolation.target_fps` - 如果启用视频帧插值,使用此帧率作为输出帧率
2. `fps`(默认 16)- 如果未启用视频帧插值,使用此帧率;同时总是用作源帧率
**注意**: 系统不再使用 `video_fps` 配置项,统一使用 `video_frame_interpolation.target_fps` 来控制输出视频的帧率。
## 工作原理
### 帧插值过程
1. **源视频生成**: 基础模型以源 FPS 生成视频帧
2. **帧分析**: RIFE 分析相邻帧以估计光流
3. **中间帧生成**: 在现有帧之间生成新帧
4. **时序平滑**: 插值帧创建平滑的运动过渡
### 技术细节
- **输入格式**: ComfyUI 图像张量 [N, H, W, C],范围 [0, 1]
- **输出格式**: 插值后的 ComfyUI 图像张量 [M, H, W, C],范围 [0, 1]
- **处理**: 自动填充和分辨率处理
- **内存优化**: 高效的 GPU 内存管理
## 示例配置
### 基础帧率翻倍
创建配置文件 `wan_t2v_vfi_32fps.json`
```json
{
"infer_steps": 50,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"seed": 42,
"sample_guide_scale": 6,
"enable_cfg": true,
"fps": 16,
"video_frame_interpolation": {
"algo": "rife",
"target_fps": 32,
"model_path": "/path/to/rife/train_log"
}
}
```
运行命令:
```bash
python lightx2v/infer.py \
--model_cls wan2.1 \
--task t2v \
--model_path ./models/wan2.1 \
--config_json ./wan_t2v_vfi_32fps.json \
--prompt "一只小猫在花园里玩耍" \
--save_video_path ./output_32fps.mp4
```
### 更高帧率增强
创建配置文件 `wan_i2v_vfi_60fps.json`
```json
{
"infer_steps": 30,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"seed": 42,
"sample_guide_scale": 6,
"enable_cfg": true,
"fps": 16,
"video_frame_interpolation": {
"algo": "rife",
"target_fps": 60,
"model_path": "/path/to/rife/train_log"
}
}
```
运行命令:
```bash
python lightx2v/infer.py \
--model_cls wan2.1 \
--task i2v \
--model_path ./models/wan2.1 \
--config_json ./wan_i2v_vfi_60fps.json \
--image_path ./input.jpg \
--prompt "平滑的相机运动" \
--save_video_path ./output_60fps.mp4
```
## 性能考虑
### 内存使用
- RIFE 处理需要额外的 GPU 内存
- 内存使用量与视频分辨率和长度成正比
- 对于较长的视频,考虑使用较低的分辨率
### 处理时间
- 帧插值会增加处理开销
- 更高的目标帧率需要更多计算
- 处理时间大致与插值帧数成正比
### 质量与速度权衡
- 更高的插值比率可能引入伪影
- 最佳范围:2x 到 4x 帧率增加
- 对于极端插值(>4x),考虑多次处理
## 最佳实践
### 最佳使用场景
- **运动密集视频**: 从帧插值中受益最多
- **相机运动**: 更平滑的平移和缩放
- **动作序列**: 减少运动模糊感知
- **慢动作效果**: 创建流畅的慢动作视频
### 推荐设置
- **源 FPS**: 16-24 FPS(基础模型生成)
- **目标 FPS**: 32-60 FPS(2x 到 4x 增加)
- **分辨率**: 最高 720p 以获得最佳性能
### 故障排除
#### 常见问题
1. **内存不足**: 减少视频分辨率或目标 FPS
2. **输出中有伪影**: 降低插值比率
3. **处理缓慢**: 检查 GPU 内存并考虑使用 CPU 卸载
#### 解决方案
通过修改配置文件来解决问题:
```json
{
// 内存问题解决:使用较低分辨率
"target_height": 480,
"target_width": 832,
// 质量问题解决:使用适中的插值
"video_frame_interpolation": {
"target_fps": 24 // 而不是 60
},
// 性能问题解决:启用卸载
"cpu_offload": true
}
```
## 技术实现
LightX2V 中的 RIFE 集成包括:
- **RIFEWrapper**: 与 ComfyUI 兼容的 RIFE 模型包装器
- **自动模型加载**: 与推理管道的无缝集成
- **内存优化**: 高效的张量管理和 GPU 内存使用
- **质量保持**: 在添加帧的同时保持原始视频质量
import gc
from PIL import Image
from loguru import logger
import requests
from requests.exceptions import RequestException
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
from .base_runner import BaseRunner
......@@ -48,12 +49,22 @@ class DefaultRunner(BaseRunner):
else:
self.init_device = torch.device("cuda")
def load_vfi_model(self):
if self.config["video_frame_interpolation"].get("algo", None) == "rife":
from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper
logger.info("Loading RIFE model...")
return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
else:
raise ValueError(f"Unsupported VFI algorithm: {self.config['video_frame_interpolation'].get('algo')}")
@ProfilingContext("Load models")
def load_model(self):
self.model = self.load_transformer()
self.text_encoders = self.load_text_encoder()
self.image_encoder = self.load_image_encoder()
self.vae_encoder, self.vae_decoder = self.load_vae()
self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
def check_sub_servers(self, task_type):
urls = self.config.get("sub_servers", {}).get(task_type, [])
......@@ -178,11 +189,6 @@ class DefaultRunner(BaseRunner):
gc.collect()
return images
@ProfilingContext("Save video")
def save_video(self, images):
if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
self.save_video_func(images)
def post_prompt_enhancer(self):
while True:
for url in self.config["sub_servers"]["prompt_enhancer"]:
......@@ -210,9 +216,25 @@ class DefaultRunner(BaseRunner):
latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator)
images = vae_to_comfyui_image(images)
if "video_frame_interpolation" in self.config:
assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
images = self.vfi_model.interpolate_frames(
images,
source_fps=self.config.get("fps", 16),
target_fps=target_fps,
)
if save_video:
self.save_video(images)
if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
fps = self.config["video_frame_interpolation"]["target_fps"]
else:
fps = self.config.get("fps", 16)
logger.info(f"Saving video to {self.config.save_video_path}")
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg") # type: ignore
del latents, generator
torch.cuda.empty_cache()
......
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EPE(nn.Module):
def __init__(self):
super(EPE, self).__init__()
def forward(self, flow, gt, loss_mask):
loss_map = (flow - gt.detach()) ** 2
loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
return loss_map * loss_mask
class Ternary(nn.Module):
def __init__(self):
super(Ternary, self).__init__()
patch_size = 7
out_channels = patch_size * patch_size
self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
self.w = np.transpose(self.w, (3, 2, 0, 1))
self.w = torch.tensor(self.w).float().to(device)
def transform(self, img):
patches = F.conv2d(img, self.w, padding=3, bias=None)
transf = patches - img
transf_norm = transf / torch.sqrt(0.81 + transf**2)
return transf_norm
def rgb2gray(self, rgb):
r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
return gray
def hamming(self, t1, t2):
dist = (t1 - t2) ** 2
dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
return dist_norm
def valid_mask(self, t, padding):
n, _, h, w = t.size()
inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
mask = F.pad(inner, [padding] * 4)
return mask
def forward(self, img0, img1):
img0 = self.transform(self.rgb2gray(img0))
img1 = self.transform(self.rgb2gray(img1))
return self.hamming(img0, img1) * self.valid_mask(img0, 1)
class SOBEL(nn.Module):
def __init__(self):
super(SOBEL, self).__init__()
self.kernelX = torch.tensor(
[
[1, 0, -1],
[2, 0, -2],
[1, 0, -1],
]
).float()
self.kernelY = self.kernelX.clone().T
self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
def forward(self, pred, gt):
N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
img_stack = torch.cat([pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0)
sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
pred_X, gt_X = sobel_stack_x[: N * C], sobel_stack_x[N * C :]
pred_Y, gt_Y = sobel_stack_y[: N * C], sobel_stack_y[N * C :]
L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y)
loss = L1X + L1Y
return loss
class MeanShift(nn.Conv2d):
def __init__(self, data_mean, data_std, data_range=1, norm=True):
c = len(data_mean)
super(MeanShift, self).__init__(c, c, kernel_size=1)
std = torch.Tensor(data_std)
self.weight.data = torch.eye(c).view(c, c, 1, 1)
if norm:
self.weight.data.div_(std.view(c, 1, 1, 1))
self.bias.data = -1 * data_range * torch.Tensor(data_mean)
self.bias.data.div_(std)
else:
self.weight.data.mul_(std.view(c, 1, 1, 1))
self.bias.data = data_range * torch.Tensor(data_mean)
self.requires_grad = False
class VGGPerceptualLoss(torch.nn.Module):
def __init__(self, rank=0):
super(VGGPerceptualLoss, self).__init__()
blocks = []
pretrained = True
self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
for param in self.parameters():
param.requires_grad = False
def forward(self, X, Y, indices=None):
X = self.normalize(X)
Y = self.normalize(Y)
indices = [2, 7, 12, 21, 30]
weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5]
k = 0
loss = 0
for i in range(indices[-1]):
X = self.vgg_pretrained_features[i](X)
Y = self.vgg_pretrained_features[i](Y)
if (i + 1) in indices:
loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
k += 1
return loss
if __name__ == "__main__":
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(0, 1, (3, 3, 256, 256))).float().to(device)
ternary_loss = Ternary()
print(ternary_loss(img0, img1).shape)
import torch
import torch.nn.functional as F
from math import exp
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
return gauss / gauss.sum()
def create_window(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0).to(device)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def create_window_3d(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t())
_3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().to(device)
return window
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
if torch.max(img1) > 128:
max_val = 255
else:
max_val = 1
if torch.min(img1) < -0.5:
min_val = -1
else:
min_val = 0
L = max_val - min_val
else:
L = val_range
padd = 0
(_, channel, height, width) = img1.size()
if window is None:
real_size = min(window_size, height, width)
window = create_window(real_size, channel=channel).to(img1.device)
# mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
# mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
mu1 = F.conv2d(F.pad(img1, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
mu2 = F.conv2d(F.pad(img2, (5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(F.pad(img1 * img1, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(F.pad(img2 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu2_sq
sigma12 = F.conv2d(F.pad(img1 * img2, (5, 5, 5, 5), "replicate"), window, padding=padd, groups=channel) - mu1_mu2
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
if size_average:
ret = ssim_map.mean()
else:
ret = ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
if torch.max(img1) > 128:
max_val = 255
else:
max_val = 1
if torch.min(img1) < -0.5:
min_val = -1
else:
min_val = 0
L = max_val - min_val
else:
L = val_range
padd = 0
(_, _, height, width) = img1.size()
if window is None:
real_size = min(window_size, height, width)
window = create_window_3d(real_size, channel=1).to(img1.device)
# Channel is set to 1 since we consider color images as volumetric images
img1 = img1.unsqueeze(1)
img2 = img2.unsqueeze(1)
mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq
sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq
sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = torch.mean(v1 / v2) # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
if size_average:
ret = ssim_map.mean()
else:
ret = ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
device = img1.device
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
levels = weights.size()[0]
mssim = []
mcs = []
for _ in range(levels):
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
mssim.append(sim)
mcs.append(cs)
img1 = F.avg_pool2d(img1, (2, 2))
img2 = F.avg_pool2d(img2, (2, 2))
mssim = torch.stack(mssim)
mcs = torch.stack(mcs)
# Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
if normalize:
mssim = (mssim + 1) / 2
mcs = (mcs + 1) / 2
pow1 = mcs**weights
pow2 = mssim**weights
# From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
output = torch.prod(pow1[:-1] * pow2[-1])
return output
# Classes to re-use window
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, val_range=None):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.val_range = val_range
# Assume 3 channel for SSIM
self.channel = 3
self.window = create_window(window_size, channel=self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window
else:
window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
self.window = window
self.channel = channel
_ssim = ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
dssim = (1 - _ssim) / 2
return dssim
class MSSSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
super(MSSSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel
def forward(self, img1, img2):
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backwarp_tenGrid = {}
def warp(tenInput, tenFlow):
k = (str(tenFlow.device), str(tenFlow.size()))
if k not in backwarp_tenGrid:
tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(device)
tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True)
import os
from typing import List, Optional, Tuple
from loguru import logger
import torch
from torch.nn import functional as F
from lightx2v.utils.profiler import ProfilingContext
class RIFEWrapper:
"""Wrapper for RIFE model to work with ComfyUI Image tensors"""
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
def __init__(self, model_path, device: Optional[torch.device] = None):
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Setup torch for optimal performance
torch.set_grad_enabled(False)
if torch.cuda.is_available():
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Load model
from .train_log.RIFE_HDv3 import Model
self.model = Model()
with ProfilingContext("Load RIFE model"):
self.model.load_model(model_path, -1)
self.model.eval()
self.model.device()
@ProfilingContext("Interpolate frames")
def interpolate_frames(
self,
images: torch.Tensor,
source_fps: float,
target_fps: float,
scale: float = 1.0,
) -> torch.Tensor:
"""
Interpolate frames from source FPS to target FPS
Args:
images: ComfyUI Image tensor [N, H, W, C] in range [0, 1]
source_fps: Source frame rate
target_fps: Target frame rate
scale: Scale factor for processing
Returns:
Interpolated ComfyUI Image tensor [M, H, W, C] in range [0, 1]
"""
# Validate input
assert images.dim() == 4 and images.shape[-1] == 3, "Input must be [N, H, W, C] with C=3"
if source_fps == target_fps:
return images
total_source_frames = images.shape[0]
height, width = images.shape[1:3]
# Calculate padding for model
tmp = max(128, int(128 / scale))
ph = ((height - 1) // tmp + 1) * tmp
pw = ((width - 1) // tmp + 1) * tmp
padding = (0, pw - width, 0, ph - height)
# Calculate target frame positions
frame_positions = self._calculate_target_frame_positions(source_fps, target_fps, total_source_frames)
# Prepare output tensor
output_frames = []
for source_idx1, source_idx2, interp_factor in frame_positions:
if interp_factor == 0.0 or source_idx1 == source_idx2:
# No interpolation needed, use the source frame directly
output_frames.append(images[source_idx1])
else:
# Get frames to interpolate
frame1 = images[source_idx1]
frame2 = images[source_idx2]
# Convert ComfyUI format [H, W, C] to RIFE format [1, C, H, W]
# Values are already in [0, 1], so no rescaling is needed
I0 = frame1.permute(2, 0, 1).unsqueeze(0).to(self.device)
I1 = frame2.permute(2, 0, 1).unsqueeze(0).to(self.device)
# Pad images
I0 = F.pad(I0, padding)
I1 = F.pad(I1, padding)
# Perform interpolation
with torch.no_grad():
interpolated = self.model.inference(I0, I1, timestep=interp_factor, scale=scale)
# Convert back to ComfyUI format [H, W, C]
# Crop to original size and permute dimensions
interpolated_frame = interpolated[0, :, :height, :width].permute(1, 2, 0).cpu()
output_frames.append(interpolated_frame)
# Stack all frames
return torch.stack(output_frames, dim=0)
def _calculate_target_frame_positions(self, source_fps: float, target_fps: float, total_source_frames: int) -> List[Tuple[int, int, float]]:
"""
Calculate which frames need to be generated for the target frame rate.
Returns:
List of (source_frame_index1, source_frame_index2, interpolation_factor) tuples
"""
frame_positions = []
# Calculate the time duration of the video
duration = (total_source_frames - 1) / source_fps
# Calculate number of target frames
total_target_frames = int(duration * target_fps) + 1
for target_idx in range(total_target_frames):
# Calculate the time position of this target frame
target_time = target_idx / target_fps
# Calculate the corresponding position in source frames
source_position = target_time * source_fps
# Find the two source frames to interpolate between
source_idx1 = int(source_position)
source_idx2 = min(source_idx1 + 1, total_source_frames - 1)
# Calculate interpolation factor (0 means use frame1, 1 means use frame2)
if source_idx1 == source_idx2:
interpolation_factor = 0.0
else:
interpolation_factor = source_position - source_idx1
frame_positions.append((source_idx1, source_idx2, interpolation_factor))
return frame_positions
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..model.warplayer import warp
# from train_log.refine import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True,
),
nn.LeakyReLU(0.2, True),
)
def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=False,
),
nn.BatchNorm2d(out_planes),
nn.LeakyReLU(0.2, True),
)
class Head(nn.Module):
def __init__(self):
super(Head, self).__init__()
self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
self.relu = nn.LeakyReLU(0.2, True)
def forward(self, x, feat=False):
x0 = self.cnn0(x)
x = self.relu(x0)
x1 = self.cnn1(x)
x = self.relu(x1)
x2 = self.cnn2(x)
x = self.relu(x2)
x3 = self.cnn3(x)
if feat:
return [x0, x1, x2, x3]
return x3
class ResConv(nn.Module):
def __init__(self, c, dilation=1):
super(ResConv, self).__init__()
self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
self.relu = nn.LeakyReLU(0.2, True)
def forward(self, x):
return self.relu(self.conv(x) * self.beta + x)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c // 2, 3, 2, 1),
conv(c // 2, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
ResConv(c),
)
self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1), nn.PixelShuffle(2))
def forward(self, x, flow=None, scale=1):
x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
if flow is not None:
flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale
x = torch.cat((x, flow), 1)
feat = self.conv0(x)
feat = self.convblock(feat)
tmp = self.lastconv(feat)
tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
flow = tmp[:, :4] * scale
mask = tmp[:, 4:5]
feat = tmp[:, 5:]
return flow, mask, feat
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(7 + 8, c=192)
self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
self.block2 = IFBlock(8 + 4 + 8 + 8, c=96)
self.block3 = IFBlock(8 + 4 + 8 + 8, c=64)
self.block4 = IFBlock(8 + 4 + 8 + 8, c=32)
self.encode = Head()
# not used during inference
"""
self.teacher = IFBlock(8+4+8+3+8, c=64)
self.caltime = nn.Sequential(
nn.Conv2d(16+9, 8, 3, 2, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(32, 64, 3, 2, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(64, 64, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(64, 64, 3, 1, 1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(64, 1, 3, 1, 1),
nn.Sigmoid()
)
"""
def forward(
self,
x,
timestep=0.5,
scale_list=[8, 4, 2, 1],
training=False,
fastmode=True,
ensemble=False,
):
if not training:
channel = x.shape[1] // 2
img0 = x[:, :channel]
img1 = x[:, channel:]
if not torch.is_tensor(timestep):
timestep = (x[:, :1].clone() * 0 + 1) * timestep
else:
timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
f0 = self.encode(img0[:, :3])
f1 = self.encode(img1[:, :3])
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = None
mask = None
loss_cons = 0
block = [self.block0, self.block1, self.block2, self.block3, self.block4]
for i in range(5):
if flow is None:
flow, mask, feat = block[i](
torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
None,
scale=scale_list[i],
)
if ensemble:
print("warning: ensemble is not supported since RIFEv4.21")
else:
wf0 = warp(f0, flow[:, :2])
wf1 = warp(f1, flow[:, 2:4])
fd, m0, feat = block[i](
torch.cat(
(
warped_img0[:, :3],
warped_img1[:, :3],
wf0,
wf1,
timestep,
mask,
feat,
),
1,
),
flow,
scale=scale_list[i],
)
if ensemble:
print("warning: ensemble is not supported since RIFEv4.21")
else:
mask = m0
flow = flow + fd
mask_list.append(mask)
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged.append((warped_img0, warped_img1))
mask = torch.sigmoid(mask)
merged[4] = warped_img0 * mask + warped_img1 * (1 - mask)
if not fastmode:
print("contextnet is removed")
"""
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, :3] * 2 - 1
merged[4] = torch.clamp(merged[4] + res, 0, 1)
"""
return flow_list, mask_list[4], merged
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from ..model.loss import *
from .IFNet_HDv3 import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Model:
def __init__(self, local_rank=-1):
self.flownet = IFNet()
self.device()
self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
self.epe = EPE()
self.version = 4.25
# self.vgg = VGGPerceptualLoss().to(device)
self.sobel = SOBEL()
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
def train(self):
self.flownet.train()
def eval(self):
self.flownet.eval()
def device(self):
self.flownet.to(device)
def load_model(self, path, rank=0):
def convert(param):
if rank == -1:
return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
else:
return param
if rank <= 0:
if torch.cuda.is_available():
self.flownet.load_state_dict(convert(torch.load(path)), False)
else:
self.flownet.load_state_dict(
convert(torch.load(path, map_location="cpu")),
False,
)
def save_model(self, path, rank=0):
if rank == 0:
torch.save(self.flownet.state_dict(), "{}/flownet.pkl".format(path))
def inference(self, img0, img1, timestep=0.5, scale=1.0):
imgs = torch.cat((img0, img1), 1)
scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale]
flow, mask, merged = self.flownet(imgs, timestep, scale_list)
return merged[-1]
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
param_group["lr"] = learning_rate
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
if training:
self.train()
else:
self.eval()
scale = [16, 8, 4, 2, 1]
flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
loss_l1 = (merged[-1] - gt).abs().mean()
loss_smooth = self.sobel(flow[-1], flow[-1] * 0).mean()
# loss_vgg = self.vgg(merged[-1], gt)
if training:
self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
loss_G.backward()
self.optimG.step()
else:
flow_teacher = flow[2]
return merged[-1], {
"mask": mask,
"flow": flow[-1][:, :2],
"loss_l1": loss_l1,
"loss_cons": loss_cons,
"loss_smooth": loss_smooth,
}
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..model.warplayer import warp
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True,
),
nn.LeakyReLU(0.2, True),
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True,
),
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(
in_channels=in_planes,
out_channels=out_planes,
kernel_size=4,
stride=2,
padding=1,
bias=True,
),
nn.LeakyReLU(0.2, True),
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 16
class Contextnet(nn.Module):
def __init__(self):
super(Contextnet, self).__init__()
self.conv1 = Conv2(3, c)
self.conv2 = Conv2(c, 2 * c)
self.conv3 = Conv2(2 * c, 4 * c)
self.conv4 = Conv2(4 * c, 8 * c)
def forward(self, x, flow):
x = self.conv1(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class Unet(nn.Module):
def __init__(self):
super(Unet, self).__init__()
self.down0 = Conv2(17, 2 * c)
self.down1 = Conv2(4 * c, 4 * c)
self.down2 = Conv2(8 * c, 8 * c)
self.down3 = Conv2(16 * c, 16 * c)
self.up0 = deconv(32 * c, 8 * c)
self.up1 = deconv(16 * c, 4 * c)
self.up2 = deconv(8 * c, 2 * c)
self.up3 = deconv(4 * c, c)
self.conv = nn.Conv2d(c, 3, 3, 1, 1)
def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return torch.sigmoid(x)
import os
import random
import subprocess
from typing import Optional
from einops import rearrange
import imageio
import imageio_ffmpeg as ffmpeg
from loguru import logger
import numpy as np
import torch
import torchvision
import numpy as np
import imageio
import random
import os
def seed_all(seed):
......@@ -50,7 +51,7 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, f
def cache_video(
tensor,
save_file,
save_file: str,
fps=30,
suffix=".mp4",
nrow=8,
......@@ -73,7 +74,7 @@ def cache_video(
for _ in range(retry):
try:
# preprocess
tensor = tensor.clamp(min(value_range), max(value_range))
tensor = tensor.clamp(min(value_range), max(value_range)) # type: ignore
tensor = torch.stack(
[torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) for u in tensor.unbind(2)],
dim=1,
......@@ -94,3 +95,160 @@ def cache_video(
else:
logger.info(f"cache_video failed, error: {error}", flush=True)
return None
def vae_to_comfyui_image(vae_output: torch.Tensor) -> torch.Tensor:
"""
Convert VAE decoder output to ComfyUI Image format
Args:
vae_output: VAE decoder output tensor, typically in range [-1, 1]
Shape: [B, C, T, H, W] or [B, C, H, W]
Returns:
ComfyUI Image tensor in range [0, 1]
Shape: [B, H, W, C] for single frame or [B*T, H, W, C] for video
"""
# Handle video tensor (5D) vs image tensor (4D)
if vae_output.dim() == 5:
# Video tensor: [B, C, T, H, W]
B, C, T, H, W = vae_output.shape
# Reshape to [B*T, C, H, W] for processing
vae_output = vae_output.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
# Normalize from [-1, 1] to [0, 1]
images = (vae_output + 1) / 2
# Clamp values to [0, 1]
images = torch.clamp(images, 0, 1)
# Convert from [B, C, H, W] to [B, H, W, C]
images = images.permute(0, 2, 3, 1).cpu()
return images
def save_to_video(
images: torch.Tensor,
output_path: str,
fps: float = 24.0,
method: str = "imageio",
lossless: bool = False,
output_pix_fmt: Optional[str] = "yuv420p",
) -> None:
"""
Save ComfyUI Image tensor to video file
Args:
images: ComfyUI Image tensor [N, H, W, C] in range [0, 1]
output_path: Path to save the video
fps: Frames per second
method: Save method - "imageio" or "ffmpeg"
lossless: Whether to use lossless encoding (ffmpeg method only)
output_pix_fmt: Pixel format for output (ffmpeg method only)
"""
assert images.dim() == 4 and images.shape[-1] == 3, "Input must be [N, H, W, C] with C=3"
# Ensure output directory exists
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
if method == "imageio":
# Convert to uint8
frames = (images * 255).cpu().numpy().astype(np.uint8)
imageio.mimsave(output_path, frames, fps=fps) # type: ignore
elif method == "ffmpeg":
# Convert to numpy and scale to [0, 255]
frames = (images * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
# Convert RGB to BGR for OpenCV/FFmpeg
frames = frames[..., ::-1].copy()
N, height, width, _ = frames.shape
# Ensure even dimensions for x264
width += width % 2
height += height % 2
# Get ffmpeg executable from imageio_ffmpeg
ffmpeg_exe = ffmpeg.get_ffmpeg_exe()
if lossless:
command = [
ffmpeg_exe,
"-y", # Overwrite output file if it exists
"-f",
"rawvideo",
"-s",
f"{int(width)}x{int(height)}",
"-pix_fmt",
"bgr24",
"-r",
f"{fps}",
"-loglevel",
"error",
"-threads",
"4",
"-i",
"-", # Input from pipe
"-vcodec",
"libx264rgb",
"-crf",
"0",
"-an", # No audio
output_path,
]
else:
command = [
ffmpeg_exe,
"-y", # Overwrite output file if it exists
"-f",
"rawvideo",
"-s",
f"{int(width)}x{int(height)}",
"-pix_fmt",
"bgr24",
"-r",
f"{fps}",
"-loglevel",
"error",
"-threads",
"4",
"-i",
"-", # Input from pipe
"-vcodec",
"libx264",
"-pix_fmt",
output_pix_fmt,
"-an", # No audio
output_path,
]
# Run FFmpeg
process = subprocess.Popen(
command,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if process.stdin is None:
raise BrokenPipeError("No stdin buffer received.")
# Write frames to FFmpeg
for frame in frames:
# Pad frame if needed
if frame.shape[0] < height or frame.shape[1] < width:
padded = np.zeros((height, width, 3), dtype=np.uint8)
padded[: frame.shape[0], : frame.shape[1]] = frame
frame = padded
process.stdin.write(frame.tobytes())
process.stdin.close()
process.wait()
if process.returncode != 0:
error_output = process.stderr.read().decode() if process.stderr else "Unknown error"
raise RuntimeError(f"FFmpeg failed with error: {error_output}")
else:
raise ValueError(f"Unknown save method: {method}")
#!/bin/bash
# Video Frame Interpolation Example Script for WAN T2V
# This script demonstrates how to use RIFE frame interpolation with LightX2V
# VFI is enabled through configuration file, not command line parameters
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16 # remove this line to get higher-quality video
# Run inference with VFI enabled through config file
# The wan_t2v.json config contains video_frame_interpolation settings
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/video_frame_interpolation/wan_t2v.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_video_frame_interpolation.mp4
#!/usr/bin/env python3
# coding: utf-8
import os
import sys
import requests
import zipfile
import shutil
import argparse
from pathlib import Path
def get_base_dir():
"""Get project root directory"""
return Path(__file__).parent.parent
def download_file(url, save_path):
"""Download file"""
print(f"Starting download: {url}")
response = requests.get(url, stream=True)
response.raise_for_status()
total_size = int(response.headers.get("content-length", 0))
downloaded_size = 0
with open(save_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
downloaded_size += len(chunk)
if total_size > 0:
progress = (downloaded_size / total_size) * 100
print(f"\rDownload progress: {progress:.1f}%", end="", flush=True)
print(f"\nDownload completed: {save_path}")
def extract_zip(zip_path, extract_to):
"""Extract zip file"""
print(f"Starting extraction: {zip_path}")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
print(f"Extraction completed: {extract_to}")
def find_flownet_pkl(extract_dir):
"""Find flownet.pkl file in extracted directory"""
for root, dirs, files in os.walk(extract_dir):
for file in files:
if file == "flownet.pkl":
return os.path.join(root, file)
return None
def main():
parser = argparse.ArgumentParser(description="Download RIFE model to specified directory")
parser.add_argument("target_directory", help="Target directory path")
args = parser.parse_args()
target_dir = Path(args.target_directory)
if not target_dir.is_absolute():
target_dir = Path.cwd() / target_dir
base_dir = get_base_dir()
temp_dir = base_dir / "_temp"
# Create temporary directory
temp_dir.mkdir(exist_ok=True)
target_dir.mkdir(parents=True, exist_ok=True)
zip_url = "https://huggingface.co/hzwer/RIFE/resolve/main/RIFEv4.26_0921.zip"
zip_path = temp_dir / "RIFEv4.26_0921.zip"
try:
# Download zip file
download_file(zip_url, zip_path)
# Extract file
extract_zip(zip_path, temp_dir)
# Find flownet.pkl file
flownet_pkl = find_flownet_pkl(temp_dir)
if flownet_pkl:
# Copy flownet.pkl to target directory
target_file = target_dir / "flownet.pkl"
shutil.copy2(flownet_pkl, target_file)
print(f"flownet.pkl copied to: {target_file}")
else:
print("Error: flownet.pkl file not found")
return 1
# Clean up temporary files
print("Cleaning up temporary files...")
if zip_path.exists():
zip_path.unlink()
print(f"Deleted: {zip_path}")
# Delete extracted folders
for item in temp_dir.iterdir():
if item.is_dir():
shutil.rmtree(item)
print(f"Deleted directory: {item}")
# Delete the temp directory itself if empty
if temp_dir.exists() and not any(temp_dir.iterdir()):
temp_dir.rmdir()
print(f"Deleted temp directory: {temp_dir}")
print("RIFE model download and installation completed!")
return 0
except Exception as e:
print(f"Error: {e}")
return 1
finally:
if zip_path.exists():
try:
zip_path.unlink()
except Exception as e:
print(f"Error: {e}")
if __name__ == "__main__":
sys.exit(main())