Commit 4b8f0d55 authored by GoatWu

update docs

parent 9f8e3b37
@@ -4,10 +4,68 @@ Step distillation is an important optimization technique in LightX2V. By trainin
## 🔍 Technical Principle
Step distillation is implemented through [Self-Forcing](https://github.com/guandeh17/Self-Forcing) technology. Self-Forcing performs step distillation and CFG distillation on 1.3B autoregressive models. LightX2V extends it with a series of enhancements:
### DMD Distillation
The core technology of step distillation is [DMD Distillation](https://arxiv.org/abs/2311.18828). The DMD distillation framework is shown in the following diagram:
<div align="center">
<img alt="DMD Distillation Framework" src="https://raw.githubusercontent.com/ModelTC/LightX2V/main/assets/figs/step_distill/fig_01.png" width="75%">
</div>
The core idea of DMD distillation is to minimize the KL divergence between the output distributions of the distilled model and the original model:
$$
\begin{aligned}
D_{KL}\left(p_{\text{fake}} \; \| \; p_{\text{real}} \right) &= \mathbb{E}_{x\sim p_\text{fake}}\left[\log\left(\frac{p_\text{fake}(x)}{p_\text{real}(x)}\right)\right]\\
&= \mathbb{E}_{\substack{
z \sim \mathcal{N}(0; \mathbf{I}) \\
x = G_\theta(z)
}}\big[-\big(\log p_\text{real}(x) - \log p_\text{fake}(x)\big)\big].
\end{aligned}
$$
Since directly computing the probability density is nearly impossible, DMD distillation instead computes the gradient of this KL divergence:
$$
\begin{aligned}
\nabla_\theta D_{KL}
&= \mathbb{E}_{\substack{
z \sim \mathcal{N}(0; \mathbf{I}) \\
x = G_\theta(z)
}} \Big[-
\big(
s_\text{real}(x) - s_\text{fake}(x)\big)
\hspace{.5mm} \frac{dG}{d\theta}
\Big],
\end{aligned}
$$
where $s_\text{real}(x) = \nabla_{x} \log p_\text{real}(x)$ and $s_\text{fake}(x) = \nabla_{x} \log p_\text{fake}(x)$ are score functions. A score function can be estimated by a diffusion model, so DMD distillation maintains three models in total:
- `real_score` computes the score of the real distribution; since the real distribution is fixed, DMD distillation uses the original model with frozen weights as its score function;
- `fake_score` computes the score of the fake distribution; since the fake distribution keeps changing, DMD distillation initializes it from the original model and fine-tunes it to track the generator's output distribution;
- `generator`, the student model, whose update direction is given by the gradient of the KL divergence computed from the difference between `real_score` and `fake_score` (a minimal sketch follows the references below).
> References:
> 1. [DMD (One-step Diffusion with Distribution Matching Distillation)](https://arxiv.org/abs/2311.18828)
> 2. [DMD2 (Improved Distribution Matching Distillation for Fast Image Synthesis)](https://arxiv.org/abs/2405.14867)
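The gradient above is usually injected through a simple surrogate loss, as in common DMD-style implementations: the score difference is treated as a fixed direction and backpropagated to the generator via an MSE term. The sketch below is a minimal illustration under that assumption; `generator`, `real_score_model` and `fake_score_model` are hypothetical placeholders for the three models, not LightX2V's actual training code.

```python
import torch
import torch.nn.functional as F


def dmd_generator_loss(generator, real_score_model, fake_score_model, z, t):
    """Surrogate loss whose gradient matches the DMD update above (sketch only)."""
    x = generator(z, t)  # generator sample x = G_theta(z)

    with torch.no_grad():
        # Predictions from the frozen teacher (real_score) and the continually
        # fine-tuned fake_score model stand in for the two score functions.
        s_real = real_score_model(x, t)
        s_fake = fake_score_model(x, t)
        # Distribution-matching direction: -(s_real - s_fake).
        grad = s_fake - s_real

    # Minimizing this MSE moves x along (s_real - s_fake), i.e. it descends the
    # KL divergence above with respect to the generator parameters.
    return 0.5 * F.mse_loss(x, (x - grad).detach())
```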
### Self-Forcing
DMD distillation was designed for image generation. Step distillation in LightX2V is implemented based on [Self-Forcing](https://github.com/guandeh17/Self-Forcing). Self-Forcing's overall implementation is similar to DMD, but following DMD2 it drops the regression loss and uses ODE initialization instead. Additionally, Self-Forcing adds an important optimization for video generation tasks:
Current DMD-based methods struggle to generate a video in a single step. Self-Forcing samples one timestep per training iteration, and the generator computes gradients only at that step. This significantly speeds up training, improves denoising quality at intermediate timesteps, and thereby improves the final results (a minimal sketch of the idea follows the references below).
> References:
> 1. [Self-Forcing (Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion)](https://arxiv.org/abs/2506.08009)
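As a rough illustration of this per-timestep optimization, the sketch below runs the few-step sampler while keeping gradients only at one randomly chosen timestep. `generator` and `scheduler` are hypothetical placeholders; the real Self-Forcing training loop is considerably more involved.

```python
import random

import torch

# Must match the timesteps used at inference time (see denoising_step_list below).
DENOISING_STEPS = [1000, 750, 500, 250]


def sample_with_single_grad_step(generator, scheduler, noise):
    """Run the few-step sampler, keeping gradients at one randomly chosen step only."""
    grad_step = random.choice(DENOISING_STEPS)
    x = noise
    for t in DENOISING_STEPS:
        if t == grad_step:
            pred = generator(x, t)          # the only step where gradients flow
        else:
            with torch.no_grad():
                pred = generator(x, t)      # plain inference, no gradients
        x = scheduler.step(pred, t, x)      # advance to the next noise level
    return x, grad_step
```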
### LightX2V
Self-Forcing performs step distillation and CFG distillation on 1.3B autoregressive models. LightX2V extends it with a series of enhancements:
1. **Larger Models**: Supports step distillation training for 14B models;
2. **More Model Types**: Supports step distillation training for standard bidirectional models and for I2V models;
3. **Better Results**: Trains on roughly 50,000 data entries with high-quality prompts.
For detailed implementation, refer to [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus).
@@ -16,7 +74,7 @@ For detailed implementation, refer to [Self-Forcing-Plus](https://github.com/Goa
- **Inference Acceleration**: Reduces inference steps from 40-50 to 4 steps without CFG, achieving approximately **20-24x** speedup
- **Quality Preservation**: Maintains original video generation quality through distillation techniques
- **Strong Compatibility**: Supports both T2V and I2V tasks
- **Flexible Usage**: Supports loading complete step distillation models or loading step distillation LoRA on top of native models; compatible with int8/fp8 model quantization
## 🛠️ Configuration Files
@@ -26,17 +84,23 @@ Multiple configuration options are provided in the [configs/distill/](https://gi
| Configuration File | Purpose | Model Address |
|-------------------|---------|---------------|
| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | Load T2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | Load I2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | Load Wan-T2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | Load Wan-I2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
### Key Configuration Parameters
- Since DMD distillation trains only a few fixed timesteps, we recommend using the `LCM Scheduler` for inference. [WanStepDistillScheduler](https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/schedulers/wan/step_distill/scheduler.py) already uses the `LCM Scheduler` internally, so no user configuration is required.
- `infer_steps`, `denoising_step_list` and `sample_shift` are set to match the values used during training and should generally not be modified (a timestep-shift sketch follows the example below).
- `enable_cfg` must be set to `false` (equivalent to setting `sample_guide_scale = 1`); otherwise the video may come out completely blurred.
- `lora_configs` supports merging multiple LoRAs with different strengths. When `lora_configs` is non-empty, the original `Wan2.1` model is loaded by default, so when using `lora_configs` with step distillation, set the path and strength of the step distillation LoRA explicitly.
```json
{
"infer_steps": 4, // Inference steps
"denoising_step_list": [999, 750, 500, 250], // Denoising timestep list
"denoising_step_list": [1000, 750, 500, 250], // Denoising timestep list
"sample_shift": 5, // Scheduler timestep shift
"enable_cfg": false, // Disable CFG for speed improvement
"lora_configs": [ // LoRA weights path (optional)
{
        ...
```
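To make the relationship between `denoising_step_list` and `sample_shift` concrete, the snippet below applies the timestep shift commonly used by flow-matching schedulers to the four training timesteps. The exact formula inside `WanStepDistillScheduler` may differ, so treat this purely as an illustrative sketch.

```python
# Four training timesteps and the scheduler shift from the config above.
denoising_step_list = [1000, 750, 500, 250]
sample_shift = 5
num_train_timesteps = 1000


def shift_sigma(sigma: float, shift: float) -> float:
    # Timestep shift commonly used by flow-matching schedulers.
    return shift * sigma / (1 + (shift - 1) * sigma)


sigmas = [shift_sigma(t / num_train_timesteps, sample_shift) for t in denoising_step_list]
print(sigmas)  # [1.0, 0.9375, 0.8333..., 0.625] -- noise levels skewed toward high noise
```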
@@ -4,10 +4,68 @@
## 🔍 Technical Principle
Step distillation is implemented through [Self-Forcing](https://github.com/guandeh17/Self-Forcing) technology. Self-Forcing performs step distillation and CFG distillation on 1.3B autoregressive models. LightX2V extends it with a series of enhancements:
### DMD Distillation
The core technology of step distillation is [DMD Distillation](https://arxiv.org/abs/2311.18828). The DMD distillation framework is shown in the following diagram:
<div align="center">
<img alt="DMD Distillation Framework" src="https://raw.githubusercontent.com/ModelTC/LightX2V/main/assets/figs/step_distill/fig_01.png" width="75%">
</div>
The core idea of DMD distillation is to minimize the KL divergence between the output distributions of the distilled model and the original model:
$$
\begin{aligned}
D_{KL}\left(p_{\text{fake}} \; \| \; p_{\text{real}} \right) &= \mathbb{E}_{x\sim p_\text{fake}}\left[\log\left(\frac{p_\text{fake}(x)}{p_\text{real}(x)}\right)\right]\\
&= \mathbb{E}_{\substack{
z \sim \mathcal{N}(0; \mathbf{I}) \\
x = G_\theta(z)
}}\big[-\big(\log p_\text{real}(x) - \log p_\text{fake}(x)\big)\big].
\end{aligned}
$$
Since directly computing the probability density is nearly impossible, DMD distillation instead computes the gradient of this KL divergence:
$$
\begin{aligned}
\nabla_\theta D_{KL}
&= \mathbb{E}_{\substack{
z \sim \mathcal{N}(0; \mathbf{I}) \\
x = G_\theta(z)
}} \Big[-
\big(
s_\text{real}(x) - s_\text{fake}(x)\big)
\hspace{.5mm} \frac{dG}{d\theta}
\Big],
\end{aligned}
$$
where $s_\text{real}(x) = \nabla_{x} \log p_\text{real}(x)$ and $s_\text{fake}(x) = \nabla_{x} \log p_\text{fake}(x)$ are score functions. A score function can be estimated by a diffusion model, so DMD distillation maintains three models in total:
- `real_score` computes the score of the real distribution; since the real distribution is fixed, DMD distillation uses the original model with frozen weights as its score function;
- `fake_score` computes the score of the fake distribution; since the fake distribution keeps changing, DMD distillation initializes it from the original model and fine-tunes it to track the generator's output distribution;
- `generator`, the student model, whose update direction is given by the gradient of the KL divergence computed from the difference between `real_score` and `fake_score`.
> References:
> 1. [DMD (One-step Diffusion with Distribution Matching Distillation)](https://arxiv.org/abs/2311.18828)
> 2. [DMD2 (Improved Distribution Matching Distillation for Fast Image Synthesis)](https://arxiv.org/abs/2405.14867)
### Self-Forcing
DMD distillation was designed for image generation. Step distillation in LightX2V is implemented based on [Self-Forcing](https://github.com/guandeh17/Self-Forcing). Self-Forcing's overall implementation is similar to DMD, but following DMD2 it drops the regression loss and uses ODE initialization instead. Additionally, Self-Forcing adds an important optimization for video generation tasks:
Current DMD-based methods struggle to generate a video in a single step. Self-Forcing samples one timestep per training iteration, and the generator computes gradients only at that step. This significantly speeds up training, improves denoising quality at intermediate timesteps, and thereby improves the final results.
> References:
> 1. [Self-Forcing (Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion)](https://arxiv.org/abs/2506.08009)
### LightX2V
Self-Forcing performs step distillation and CFG distillation on 1.3B autoregressive models. LightX2V extends it with a series of enhancements:
1. **Larger Models**: Supports step distillation training for 14B models;
2. **More Model Types**: Supports step distillation training for standard bidirectional models and for I2V models;
3. **Better Results**: Trains on roughly 50,000 data entries with high-quality prompts.
For detailed implementation, refer to [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus).
@@ -16,7 +74,7 @@
- **Inference Acceleration**: Reduces inference steps from 40-50 to 4 steps without CFG, achieving approximately **20-24x** speedup
- **Quality Preservation**: Maintains original video generation quality through distillation techniques
- **Strong Compatibility**: Supports both T2V and I2V tasks
- **Flexible Usage**: Supports loading complete step distillation models or loading step distillation LoRA on top of native models; compatible with int8/fp8 model quantization
## 🛠️ Configuration Files
@@ -26,17 +84,23 @@
| Configuration File | Purpose | Model Address |
|----------|------|------------|
| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | Load T2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | Load I2V 4-step distillation complete model | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/distill_models/distill_model.safetensors) |
| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | Load Wan-T2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | Load Wan-I2V model and step distillation LoRA | [hugging-face](https://huggingface.co/lightx2v/Wan2.1-I2V-14B-480P-StepDistill-CfgDistill-Lightx2v/blob/main/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors) |
### Key Configuration Parameters
- Since DMD distillation trains only a few fixed timesteps, we recommend using the `LCM Scheduler` for inference. [WanStepDistillScheduler](https://github.com/ModelTC/LightX2V/blob/main/lightx2v/models/schedulers/wan/step_distill/scheduler.py) already uses the `LCM Scheduler` internally, so no user configuration is required.
- `infer_steps`, `denoising_step_list` and `sample_shift` are set to match the values used during training and should generally not be modified.
- `enable_cfg` must be set to `false` (equivalent to setting `sample_guide_scale = 1`); otherwise the video may come out completely blurred.
- `lora_configs` supports merging multiple LoRAs with different strengths (a merging sketch follows the example below). When `lora_configs` is non-empty, the original `Wan2.1` model is loaded by default, so when using `lora_configs` with step distillation, set the path and strength of the step distillation LoRA explicitly.
```json
{
"infer_steps": 4, // 推理步数
"denoising_step_list": [999, 750, 500, 250], // 去噪时间步列表
"denoising_step_list": [1000, 750, 500, 250], // 去噪时间步列表
"sample_shift": 5, // 调度器 timestep shift
"enable_cfg": false, // 关闭CFG以提升速度
"lora_configs": [ // LoRA权重路径(可选)
{
        ...
```
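For intuition about what merging multiple LoRAs with different strengths means, here is a generic sketch of folding LoRA deltas into a base weight matrix. The `merge_loras` helper and its signature are hypothetical illustrations, not LightX2V's actual loading code.

```python
import torch


def merge_loras(base_weight: torch.Tensor, lora_list, strengths):
    """Fold LoRA deltas into a base weight: W' = W + sum_i s_i * (alpha_i / r_i) * B_i @ A_i."""
    merged = base_weight.clone()
    for (A, B, alpha), strength in zip(lora_list, strengths):
        rank = A.shape[0]  # A: (rank, in_features), B: (out_features, rank)
        merged += strength * (alpha / rank) * (B @ A)
    return merged


# Hypothetical usage mirroring two `lora_configs` entries with different strengths:
# merged_w = merge_loras(w, [(A_distill, B_distill, 64.0), (A_style, B_style, 64.0)], [1.0, 0.5])
```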