Commit b072a45f authored by wangshankun's avatar wangshankun

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents c7eb4631 46e17eec
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel AS base
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
# Use the Tsinghua mirror for apt sources
RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list \
&& sed -i 's|http://security.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list
RUN apt-get update && apt-get install -y vim tmux zip unzip wget git build-essential libibverbs-dev ca-certificates \
curl iproute2 ffmpeg libsm6 libxext6 kmod ccache libnuma-dev \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv ruff pre-commit -U
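# Build vLLM from source against the PyTorch already provided by the base image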
RUN git clone https://github.com/vllm-project/vllm.git && cd vllm \
&& python use_existing_torch.py && pip install -r requirements/build.txt \
&& pip install --no-cache-dir --no-build-isolation -v -e .
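# Build the sgl-kernel extension from the SGLang repository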
RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel \
&& make build && make clean
RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
imageio-ffmpeg einops loguru qtorch ftfy easydict
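# Build FlashAttention from source: the root package first, then the Hopper (FA3) kernels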
RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive
RUN cd flash-attention && python setup.py install && rm -rf build
RUN cd flash-attention/hopper && python setup.py install && rm -rf build
# RUN git clone https://github.com/thu-ml/SageAttention.git
# # install sageattention with hopper gpu sm9.0
# RUN cd SageAttention && sed -i 's/set()/{"9.0"}/' setup.py && EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 pip install --no-cache-dir -v -e .
WORKDIR /workspace
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false,
"parallel": {
"cfg_p_size": 2
}
}
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses",
"cfg_p_size": 2
}
}
{
"infer_steps": 40,
"target_video_length": 81,
"text_len": 512,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_shift": 5.0,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"boundary": 0.900,
"use_image_encoder": false,
"parallel": {
"seq_p_size": 4,
"seq_p_attn_type": "ulysses"
}
}
......@@ -12,10 +12,10 @@ View all available models: [LightX2V Official Model Repository](https://huggingf
### Standard Directory Structure
Using `Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V` as an example, the standard file structure is as follows:
Using `Wan2.1-I2V-14B-480P-LightX2V` as an example, the standard file structure is as follows:
```
Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/
Wan2.1-I2V-14B-480P-LightX2V/
├── fp8/ # FP8 quantized version (DIT/T5/CLIP)
│ ├── block_xx.safetensors # DIT model FP8 quantized version
│ ├── models_t5_umt5-xxl-enc-fp8.pth # T5 encoder FP8 quantized version
......@@ -31,12 +31,41 @@ Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── original/ # Original precision version (DIT/T5/CLIP)
│ ├── xx.safetensors # DIT model original precision version
│ ├── models_t5_umt5-xxl-enc-bf16.pth # T5 encoder original precision version
│ ├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP encoder original precision version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
```
Using `Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V` as an example, the standard file structure is as follows:
```
Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/
├── distill_fp8/ # FP8 quantized version (DIT/T5/CLIP)
│ ├── block_xx.safetensors # DIT model FP8 quantized version
│ ├── models_t5_umt5-xxl-enc-fp8.pth # T5 encoder FP8 quantized version
│ ├── clip-fp8.pth # CLIP encoder FP8 quantized version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── distill_int8/ # INT8 quantized version (DIT/T5/CLIP)
│ ├── block_xx.safetensors # DIT model INT8 quantized version
│ ├── models_t5_umt5-xxl-enc-int8.pth # T5 encoder INT8 quantized version
│ ├── clip-int8.pth # CLIP encoder INT8 quantized version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── distill_models/ # Original precision version (DIT/T5/CLIP)
│ ├── distill_model.safetensors # DIT model original precision version
│ ├── models_t5_umt5-xxl-enc-bf16.pth # T5 encoder original precision version
│ ├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP encoder original precision version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── loras/
│ ├── Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors # Distillation model lora
```
### 💾 Storage Recommendations
......@@ -148,24 +177,24 @@ python gradio_demo.py \
# Use Hugging Face CLI to selectively download non-quantized version
huggingface-cli download lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--local-dir ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--include "original/*"
--include "distill_models/*"
```
```bash
# Use Hugging Face CLI to selectively download FP8 quantized version
huggingface-cli download lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--local-dir ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--include "fp8/*"
--include "distill_fp8/*"
```
```bash
# Use Hugging Face CLI to selectively download INT8 quantized version
huggingface-cli download lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--local-dir ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--include "int8/*"
--include "distill_int8/*"
```
> **Important Note**: When starting inference scripts or Gradio, the `model_path` parameter still needs to be specified as the complete path without the `--include` parameter. For example: `model_path=./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V`, not `./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/int8`.
> **Important Note**: When starting inference scripts or Gradio, the `model_path` parameter still needs to be specified as the complete path without the `--include` parameter. For example: `model_path=./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V`, not `./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/distill_int8`.
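For illustration, a minimal sketch of a correct versus incorrect invocation (the `--model_path` flag name here is only illustrative; pass the path however your inference script or Gradio demo actually expects it):
```bash
# Correct: point model_path at the model root that contains distill_fp8/, distill_int8/, loras/, etc.
python gradio_demo.py --model_path ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V

# Incorrect: do not point it at the quantized sub-directory itself
# python gradio_demo.py --model_path ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/distill_int8
```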
#### 2. Start Inference
......
......@@ -27,14 +27,14 @@ We strongly recommend using the Docker environment, which is the simplest and fa
#### 1. Pull Image
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25080104`:
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags), select a tag with the latest date, such as `25080601-cu128`:
```bash
# Pull the latest version of LightX2V image
docker pull lightx2v/lightx2v:25080104
# Pull the latest version of LightX2V image, this image does not have SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu128
```
If you need to use `SageAttention`, you can use docker image versions with the `-SageSmXX` suffix. The use of `SageAttention` requires selection based on GPU type, where:
If you need to use `SageAttention`, you can use image versions with the `-SageSmXX` suffix. The use of `SageAttention` requires selection based on GPU type, where:
1. A100: -SageSm80
2. RTX30 series: -SageSm86
......@@ -42,13 +42,24 @@ If you need to use `SageAttention`, you can use docker image versions with the `
4. H100: -SageSm90
5. RTX50 series: -SageSm120
For example, to use `SageAttention` on 4090 or H100, the docker image pull command would be:
For example, to use `SageAttention` on 4090 or H100, the image pull commands are:
```bash
# For 4090
docker pull lightx2v/lightx2v:25080104-SageSm89
# For H100
docker pull lightx2v/lightx2v:25080104-SageSm90
# For 4090, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu128-SageSm89
# For H100, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu128-SageSm90
```
We recommend using the `cuda128` environment for faster inference speed. If you need to use the `cuda124` environment, you can use image versions with the `-cu124` suffix:
```bash
# cuda124 version, without SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu124
# For 4090, cuda124 version, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu124-SageSm89
# For H100, cuda124 version, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu124-SageSm90
```
#### 2. Run Container
......@@ -59,20 +70,29 @@ docker run --gpus all -itd --ipc=host --name [container_name] -v [mount_settings
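As a concrete illustration of the `docker run` template above (container name, mount settings, and image tag are placeholders to adapt to your setup):
```bash
docker run --gpus all -itd --ipc=host \
  --name lightx2v \
  -v /path/to/models:/workspace/models \
  --entrypoint /bin/bash \
  lightx2v/lightx2v:25080601-cu128
```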
#### 3. Domestic Mirror Source (Optional)
For users in mainland China, if the network is unstable when pulling images, you can pull from Aliyun:
For users in mainland China, if the network is unstable when pulling images, you can pull from Alibaba Cloud:
```bash
# Replace [tag] with the desired image tag to download
# Replace [tag] with the image tag you want to download
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:[tag]
# For example, download 25080104
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080104
# For example, download 25080601-cu128
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu128
# For example, download 25080104-SageSm89
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080104-SageSm89
# For example, download 25080601-cu128-SageSm89
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu128-SageSm89
# For example, download 25080104-SageSm90
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080104-SageSm90
# For example, download 25080601-cu128-SageSm90
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu128-SageSm90
# For example, download 25080601-cu124
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu124
# For example, download 25080601-cu124-SageSm89
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu124-SageSm89
# For example, download 25080601-cu124-SageSm90
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu124-SageSm90
```
### 🐍 Conda Environment Setup
......@@ -91,7 +111,7 @@ cd LightX2V
```bash
# Create and activate conda environment
conda create -n lightx2v python=3.12 -y
conda create -n lightx2v python=3.11 -y
conda activate lightx2v
```
......
......@@ -12,10 +12,10 @@
### Standard Directory Structure
Using `Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V` as an example, the standard file structure is as follows:
Using `Wan2.1-I2V-14B-480P-LightX2V` as an example, the standard file structure is as follows:
```
Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/
Wan2.1-I2V-14B-480P-LightX2V/
├── fp8/ # FP8 quantized version (DIT/T5/CLIP)
│ ├── block_xx.safetensors # DIT model FP8 quantized version
│ ├── models_t5_umt5-xxl-enc-fp8.pth # T5 encoder FP8 quantized version
......@@ -31,12 +31,42 @@ Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── original/ # Original precision version (DIT/T5/CLIP)
│ ├── xx.safetensors # DIT model original precision version
│ ├── models_t5_umt5-xxl-enc-bf16.pth # T5 encoder original precision version
│ ├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP encoder original precision version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
```
Using `Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V` as an example, the standard file structure is as follows:
```
Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/
├── distill_fp8/ # FP8 quantized version (DIT/T5/CLIP)
│ ├── block_xx.safetensors # DIT model FP8 quantized version
│ ├── models_t5_umt5-xxl-enc-fp8.pth # T5 encoder FP8 quantized version
│ ├── clip-fp8.pth # CLIP encoder FP8 quantized version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── distill_int8/ # INT8 quantized version (DIT/T5/CLIP)
│ ├── block_xx.safetensors # DIT model INT8 quantized version
│ ├── models_t5_umt5-xxl-enc-int8.pth # T5 encoder INT8 quantized version
│ ├── clip-int8.pth # CLIP encoder INT8 quantized version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── distill_models/ # Original precision version (DIT/T5/CLIP)
│ ├── distill_model.safetensors # DIT model original precision version
│ ├── models_t5_umt5-xxl-enc-bf16.pth # T5 encoder original precision version
│ ├── models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth # CLIP encoder original precision version
│ ├── Wan2.1_VAE.pth # VAE variational autoencoder
│ ├── taew2_1.pth # Lightweight VAE (optional)
│ └── config.json # Model configuration file
├── loras/
│ ├── Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors # Distillation model LoRA
```
### 💾 Storage Recommendations
......@@ -148,24 +178,24 @@ python gradio_demo_zh.py \
# Use Hugging Face CLI to selectively download the non-quantized version
huggingface-cli download lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--local-dir ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--include "original/*"
--include "distill_models/*"
```
```bash
# Use Hugging Face CLI to selectively download the FP8 quantized version
huggingface-cli download lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--local-dir ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--include "fp8/*"
--include "distill_fp8/*"
```
```bash
# Use Hugging Face CLI to selectively download the INT8 quantized version
huggingface-cli download lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--local-dir ./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V \
--include "int8/*"
--include "distill_int8/*"
```
> **Important Note**: When starting inference scripts or Gradio, the `model_path` parameter still needs to be specified as the complete path without the `--include` parameter. For example: `model_path=./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V`, not `./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/int8`.
> **Important Note**: When starting inference scripts or Gradio, the `model_path` parameter still needs to be specified as the complete path without the `--include` parameter. For example: `model_path=./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V`, not `./Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-LightX2V/distill_int8`.
#### 2. Start Inference
......
......@@ -27,11 +27,11 @@
#### 1. Pull Image
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25080104`:
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25080601-cu128`:
```bash
# Pull the latest version of the LightX2V image
docker pull lightx2v/lightx2v:25080104
# Pull the latest version of the LightX2V image; this image does not have SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu128
```
If you need to use `SageAttention`, you can use image versions with the `-SageSmXX` suffix. The use of `SageAttention` requires selection based on GPU type, where:
......@@ -45,10 +45,21 @@ docker pull lightx2v/lightx2v:25080104
For example, to use `SageAttention` on a 4090 or H100, the image pull commands are:
```bash
# For 4090
docker pull lightx2v/lightx2v:25080104-SageSm89
# For H100
docker pull lightx2v/lightx2v:25080104-SageSm90
# For 4090, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu128-SageSm89
# For H100, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu128-SageSm90
```
We recommend using the `cuda128` environment for faster inference speed. If you need to use the `cuda124` environment, you can use image versions with the `-cu124` suffix:
```bash
# cuda124 version, without SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu124
# For 4090, cuda124 version, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu124-SageSm89
# For H100, cuda124 version, with SageAttention installed
docker pull lightx2v/lightx2v:25080601-cu124-SageSm90
```
#### 2. Run Container
......@@ -65,14 +76,23 @@ docker run --gpus all -itd --ipc=host --name [container_name] -v [mount_settings] --ent
# Replace [tag] with the image tag you want to download
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:[tag]
# For example, download 25080104
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080104
# For example, download 25080601-cu128
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu128
# For example, download 25080104-SageSm89
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080104-SageSm89
# For example, download 25080601-cu128-SageSm89
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu128-SageSm89
# For example, download 25080104-SageSm90
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080104-SageSm90
# For example, download 25080601-cu128-SageSm90
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu128-SageSm90
# For example, download 25080601-cu124
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu124
# For example, download 25080601-cu124-SageSm89
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu124-SageSm89
# For example, download 25080601-cu124-SageSm90
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25080601-cu124-SageSm90
```
### 🐍 Conda Environment Setup
......@@ -91,7 +111,7 @@ cd LightX2V
```bash
# Create and activate conda environment
conda create -n lightx2v python=3.12 -y
conda create -n lightx2v python=3.11 -y
conda activate lightx2v
```
......
import os
import torch
from safetensors import safe_open
from lightx2v.common.ops.attn.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
......@@ -16,6 +15,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.envs import *
from lightx2v.utils.utils import find_torch_model_path
class WanCausVidModel(WanModel):
......@@ -32,23 +32,12 @@ class WanCausVidModel(WanModel):
self.transformer_infer_class = WanTransformerInferCausVid
def _load_ckpt(self, unified_dtype, sensitive_layer):
ckpt_folder = "causvid_models"
safetensors_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.safetensors")
if os.path.exists(safetensors_path):
with safe_open(safetensors_path, framework="pt") as f:
weight_dict = {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
.pin_memory()
.to(self.device)
for key in f.keys()
}
return weight_dict
ckpt_path = os.path.join(self.model_path, f"{ckpt_folder}/causal_model.pt")
ckpt_path = find_torch_model_path(self.config, self.model_path, "causvid_model.pt")
if os.path.exists(ckpt_path):
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = {
key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()
key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
for key in weight_dict.keys()
}
return weight_dict
......
import glob
import os
import torch
......@@ -10,6 +11,7 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
class WanDistillModel(WanModel):
......@@ -21,15 +23,29 @@ class WanDistillModel(WanModel):
super().__init__(model_path, config, device)
def _load_ckpt(self, unified_dtype, sensitive_layer):
if self.config.get("enable_dynamic_cfg", False):
ckpt_path = os.path.join(self.model_path, "distill_cfg_models", "distill_model.safetensors")
else:
ckpt_path = os.path.join(self.model_path, "distill_models", "distill_model.safetensors")
# For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
ckpt_path = os.path.join(self.model_path, "distill_model.pt")
if os.path.exists(ckpt_path):
logger.info(f"Loading weights from {ckpt_path}")
return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
weight_dict = {
key: (weight_dict[key].to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else weight_dict[key].to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device)
for key in weight_dict.keys()
}
return weight_dict
if self.config.get("enable_dynamic_cfg", False):
safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_cfg_models")
else:
safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_distill_ckpt", subdir="distill_models")
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return super()._load_ckpt(unified_dtype, sensitive_layer)
return weight_dict
class Wan22MoeDistillModel(WanDistillModel, WanModel):
......
......@@ -55,7 +55,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f, h, w = grid_sizes[0]
seq_len = f * h * w
freqs_i = torch.cat(
[
......@@ -75,7 +75,7 @@ class WanTransformerDistInfer(WanTransformerInfer):
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1
seq_len = f * h * w
......
......@@ -85,6 +85,13 @@ class WanTransformerInfer(BaseTransformerInfer):
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
return cu_seqlens_q, cu_seqlens_k
def compute_freqs(self, q, grid_sizes, freqs):
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
......@@ -109,6 +116,8 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs,
context,
)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
self.weights_stream_mgr.swap_weights()
......@@ -137,6 +146,8 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs,
context,
)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
self.weights_stream_mgr.swap_weights()
......@@ -179,6 +190,8 @@ class WanTransformerInfer(BaseTransformerInfer):
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
......@@ -239,6 +252,8 @@ class WanTransformerInfer(BaseTransformerInfer):
elif cur_phase_idx == 3:
y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
......@@ -275,6 +290,14 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
return freqs_cis
@torch._dynamo.disable
def _apply_audio_dit(self, x, block_idx, grid_sizes, audio_dit_blocks):
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
x = self.infer_block(
......@@ -287,12 +310,9 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs,
context,
)
if audio_dit_blocks:
x = self._apply_audio_dit(x, block_idx, grid_sizes, audio_dit_blocks)
if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
......@@ -328,13 +348,6 @@ class WanTransformerInfer(BaseTransformerInfer):
return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
def compute_freqs(self, q, grid_sizes, freqs):
if "audio" in self.config.get("model_cls", ""):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
if hasattr(weights, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
......
......@@ -28,14 +28,8 @@ class WanLoraWrapper:
return lora_name
def _load_lora_file(self, file_path):
use_bfloat16 = GET_DTYPE() == "BF16"
if self.model.config and hasattr(self.model.config, "get"):
use_bfloat16 = self.model.config.get("use_bfloat16", True)
with safe_open(file_path, framework="pt") as f:
if use_bfloat16:
tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
else:
tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
return tensor_dict
def apply_lora(self, lora_name, alpha=1.0):
......@@ -52,7 +46,7 @@ class WanLoraWrapper:
self.model._init_weights(weight_dict)
logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
del lora_weights # delete to save GPU memory
del lora_weights
return True
@torch.no_grad()
......
......@@ -52,6 +52,8 @@ class WanModel:
if self.dit_quantized:
dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
if self.config.model_cls == "wan2.1_distill":
dit_quant_scheme = "distill_" + dit_quant_scheme
if dit_quant_scheme == "gguf":
self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
self.config.use_gguf = True
......
......@@ -113,7 +113,7 @@ class DefaultRunner(BaseRunner):
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4Debug("infer"):
with ProfilingContext4Debug("🚀 infer_main"):
self.model.infer(self.inputs)
with ProfilingContext4Debug("step_post"):
......@@ -233,12 +233,13 @@ class DefaultRunner(BaseRunner):
fps = self.config.get("fps", 16)
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(f"Saving video to {self.config.save_video_path}")
logger.info(f"🎬 Start to save video 🎬")
if self.config["model_cls"] != "wan2.2":
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg") # type: ignore
else:
cache_video(tensor=images, save_file=self.config.save_video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
del latents, generator
torch.cuda.empty_cache()
......
......@@ -9,10 +9,16 @@ class GraphRunner:
self.compile()
def compile(self):
logger.info("start compile...")
logger.info("=" * 60)
logger.info("🚀 Starting Model Compilation - Please wait, this may take a while... 🚀")
logger.info("=" * 60)
with ProfilingContext4Debug("compile"):
self.runner.run_step()
logger.info("end compile...")
logger.info("=" * 60)
logger.info("✅ Model Compilation Completed ✅")
logger.info("=" * 60)
def run_pipeline(self):
return self.runner.run_pipeline()
......@@ -3,7 +3,7 @@ import os
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
......@@ -302,7 +302,7 @@ class VideoGenerator:
return mask.transpose(0, 1)
@torch.no_grad()
def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
def generate_segment(self, inputs, audio_features, prev_video=None, prev_frame_length=5, segment_idx=0, total_steps=None):
"""Generate video segment"""
# Update inputs with audio features
inputs["audio_encoder_output"] = audio_features
......@@ -352,14 +352,15 @@ class VideoGenerator:
inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
# Run inference loop
total_steps = self.model.scheduler.infer_steps
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps):
logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4Debug("infer"):
with ProfilingContext4Debug("🚀 infer_main"):
self.model.infer(inputs)
with ProfilingContext4Debug("step_post"):
......@@ -695,6 +696,62 @@ class WanAudioRunner(WanRunner): # type:ignore
ret["target_shape"] = self.config.target_shape
return ret
def run_step(self):
"""Optimized pipeline with modular components"""
self.initialize()
assert self._audio_processor is not None
assert self._audio_preprocess is not None
self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
with memory_efficient_inference():
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.prepare_inputs()
# Re-initialize scheduler after image encoding sets correct dimensions
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
# Re-create video generator with updated model/scheduler
self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
# Process audio
audio_array = self._audio_processor.load_audio(self.config["audio_path"])
video_duration = self.config.get("video_duration", 5)
target_fps = self.config.get("target_fps", 16)
max_num_frames = self.config.get("target_video_length", 81)
audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
# Segment audio
audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
self._video_generator.total_segments = len(audio_segments)
# Generate video segments
prev_video = None
torch.manual_seed(self.config.seed)
# Process audio features
audio_features = self._audio_preprocess(audio_segments[0].audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
# Generate video segment
with memory_efficient_inference():
self._video_generator.generate_segment(
self.inputs.copy(), # Copy to avoid modifying original
audio_features,
prev_video=prev_video,
prev_frame_length=5,
segment_idx=0,
total_steps=1,
)
# Final cleanup
self.end_run()
@RUNNER_REGISTER("wan2.2_moe_audio")
class Wan22MoeAudioRunner(WanAudioRunner):
......
......@@ -89,7 +89,7 @@ class WanCausVidRunner(WanRunner):
self.model.scheduler.latents = self.model.scheduler.last_sample
self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
with ProfilingContext4Debug("infer"):
with ProfilingContext4Debug("🚀 infer_main"):
self.model.infer(self.inputs, kv_start, kv_end)
kv_start += self.num_frame_per_block * self.frame_seq_length
......@@ -108,7 +108,7 @@ class WanCausVidRunner(WanRunner):
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4Debug("infer"):
with ProfilingContext4Debug("🚀 infer_main"):
self.model.infer(self.inputs, kv_start, kv_end)
with ProfilingContext4Debug("step_post"):
......
......@@ -110,7 +110,7 @@ class WanSkyreelsV2DFRunner(WanRunner): # Diffustion foring for SkyReelsV2 DF I
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4Debug("infer"):
with ProfilingContext4Debug("🚀 infer_main"):
self.model.infer(self.inputs)
with ProfilingContext4Debug("step_post"):
......
......@@ -20,36 +20,22 @@ class _ProfilingContext:
def __enter__(self):
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize()
if torch.cuda.is_available():
peak_memory = torch.cuda.max_memory_allocated() / (1024**3) # convert to GB
logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
else:
logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
elapsed = time.perf_counter() - self.start_time
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
return False
async def __aenter__(self):
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.start_time = time.perf_counter()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize()
if torch.cuda.is_available():
peak_memory = torch.cuda.max_memory_allocated() / (1024**3) # convert to GB
logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
else:
logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
elapsed = time.perf_counter() - self.start_time
logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
return False
......