Commit ae089db4 authored by GoatWu

Merge branch 'main' of github.com:ModelTC/lightx2v into dev-debug-distill

parents 8b213df0 4796fc6e
# comfyui deployment
# ComfyUI Deployment
xxx
This feature is coming soon.
......@@ -25,9 +25,8 @@ git clone https://github.com/ModelTC/lightx2v.git lightx2v && cd lightx2v
conda create -n lightx2v python=3.11 && conda activate lightx2v
pip install -r requirements.txt
# Reinstall transformers separately to avoid pip's dependency-conflict check
# The Hunyuan model requires transformers 4.45.2; if you don't need to run Hunyuan, you can skip this
pip install transformers==4.45.2
# pip install transformers==4.45.2
# Install flash-attention 2
git clone https://github.com/Dao-AILab/flash-attention.git --recursive
......@@ -41,7 +40,7 @@ cd flash-attention/hopper && python setup.py install
```shell
# Modify the paths in the script
bash scripts/run_wan_t2v.sh
bash scripts/wan/run_wan_t2v.sh
```
In addition to the input arguments already in the script, the `${lightx2v_path}/configs/wan_t2v.json` file pointed to by `--config_json` also contains some required parameters, which you can modify as needed.
In addition to the input arguments already in the script, the `wan_t2v.json` file pointed to by `--config_json` also contains some required parameters, which you can modify as needed.
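If you want to inspect or tweak those parameters programmatically before launching, here is a minimal sketch; the relative path and the `wan_t2v_custom.json` output name are illustrative assumptions:

```python
import json

# Inspect the parameters in the config file referenced by --config_json.
# The path below assumes you are in the repository root; adjust as needed.
with open("configs/wan_t2v.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

print(json.dumps(cfg, indent=2, ensure_ascii=False))

# After editing values in `cfg`, write them back (or to a copy) and point
# --config_json at the resulting file.
with open("configs/wan_t2v_custom.json", "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2, ensure_ascii=False)
```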
......@@ -2,17 +2,33 @@
==================
.. figure:: ../../../assets/img_lightx2v.png
:width: 100%
:width: 80%
:align: center
:alt: Lightx2v
:class: no-scaled-link
.. raw:: html
<p style="text-align:center">
<strong>A lightweight video generation inference framework
</strong>
<div align="center" style="font-family: charter;">
<a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License"></a>
<a href="https://deepwiki.com/ModelTC/lightx2v"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
<a href="https://lightx2v-en.readthedocs.io/en/latest"><img src="https://img.shields.io/badge/docs-English-99cc2" alt="Doc"></a>
<a href="https://lightx2v-zhcn.readthedocs.io/zh-cn/latest"><img src="https://img.shields.io/badge/文档-中文-99cc2" alt="Doc"></a>
<a href="https://hub.docker.com/r/lightx2v/lightx2v/tags"><img src="https://badgen.net/badge/icon/docker?icon=docker&label" alt="Docker"></a>
</div>
<div align="center" style="font-family: charter;">
<strong>LightX2V: A Lightweight Video Generation Inference Framework</strong>
</div>
LightX2V is a lightweight video generation inference framework designed to provide an inference tool that leverages a variety of advanced video generation inference techniques. It serves as a unified inference platform supporting generation tasks such as text-to-video (T2V) and image-to-video (I2V) across different models. "X2V" means converting different input modalities (X, such as text or images) to video output (V).
GitHub: https://github.com/ModelTC/lightx2v
HuggingFace: https://huggingface.co/lightx2v
Document List
-------------
......@@ -22,6 +38,7 @@
:caption: Quick Start
Quick Start <getting_started/quickstart.md>
Benchmarks <getting_started/benchmark.md>
.. toctree::
:maxdepth: 1
......@@ -32,6 +49,8 @@
Attention Mechanism <method_tutorials/attention.md>
Parameter Offloading <method_tutorials/offload.md>
Parallel Inference <method_tutorials/parallel.md>
Step Distillation <method_tutorials/step_distill.md>
Autoregressive Distillation <method_tutorials/autoregressive_distill.md>
.. toctree::
:maxdepth: 1
......@@ -39,14 +58,8 @@
Deployment for Low-Latency Scenarios <deploy_guides/for_low_latency.md>
Deployment for Low-Resource Scenarios <deploy_guides/for_low_resource.md>
LoRA Model Deployment <deploy_guides/lora_deploy.md>
Service Deployment <deploy_guides/deploy_service.md>
gradio deployment <deploy_guides/deploy_gradio.md>
comfyui deployment <deploy_guides/deploy_comfyui.md>
Gradio Deployment <deploy_guides/deploy_gradio.md>
ComfyUI Deployment <deploy_guides/deploy_comfyui.md>
Local Windows PC Deployment <deploy_guides/deploy_local_windows.md>
.. Indices and tables
.. ==================
.. * :ref:`genindex`
.. * :ref:`modindex`
# Attention Mechanism
# 🎯 Configuring Attention Types in the DiT Model
xxx
The DiT model in `LightX2V` currently uses attention in three places, and the underlying attention backend can be configured separately for each.
---
## Where Attention Is Used
1. **Self-attention over the image tokens**
- Config key: `self_attn_1_type`
2. **Cross-attention between the image and the text prompt**
- Config key: `cross_attn_1_type`
3. **Cross-attention between the image and the reference image in I2V mode**
- Config key: `cross_attn_2_type`
---
## 🚀 Supported Attention Backends
| Name | Type Name | GitHub Link |
|--------------------|------------------|-------------|
| Flash Attention 2 | `flash_attn2` | [flash-attention v2](https://github.com/Dao-AILab/flash-attention) |
| Flash Attention 3 | `flash_attn3` | [flash-attention v3](https://github.com/Dao-AILab/flash-attention) |
| Sage Attention 2 | `sage_attn2` | [SageAttention](https://github.com/thu-ml/SageAttention) |
| Radial Attention | `radial_attn` | [Radial Attention](https://github.com/mit-han-lab/radial-attention) |
| Sparge Attention | `sparge_ckpt` | [Sparge Attention](https://github.com/thu-ml/SpargeAttn) |
---
## 🛠️ Configuration Example
In the `wan_i2v.json` configuration file, you can specify which attention types to use as follows:
```json
{
"self_attn_1_type": "radial_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3"
}
```
To switch to a different type, simply replace the corresponding value with one of the type names from the table above.
Tip: due to the constraints of its sparse-attention algorithm, `radial_attn` can only be used for self-attention.
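As a rough illustration of how such a type string selects a backend at runtime, here is a hedged sketch; the registry and the `torch_sdpa` fallback are assumptions for illustration, not LightX2V's actual dispatch code:

```python
import torch
import torch.nn.functional as F

def sdpa_backend(q, k, v):
    # Illustrative fallback using PyTorch's built-in scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v)

# Hypothetical registry: a config value such as "self_attn_1_type" would be
# looked up here. Real backends (flash_attn2/3, sage_attn2, ...) would each
# wrap their own kernel; only the torch fallback is implemented in this sketch.
ATTN_BACKENDS = {
    "torch_sdpa": sdpa_backend,
}

def run_attention(attn_type: str, q, k, v):
    if attn_type not in ATTN_BACKENDS:
        raise ValueError(f"Unknown attention type: {attn_type}")
    return ATTN_BACKENDS[attn_type](q, k, v)

# Example: (batch, heads, seq_len, head_dim) tensors routed through the selected backend.
q = k = v = torch.randn(1, 8, 128, 64)
out = run_attention("torch_sdpa", q, k, v)
```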
---
For Sparge Attention, refer to the `wan_t2v_sparge.json` configuration file:
Sparge Attention requires an additional trained weight checkpoint, which is specified via `sparge_ckpt`.
```json
{
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3"
"sparge": true,
"sparge_ckpt": "/path/to/sparge_wan2.1_t2v_1.3B.pt"
}
```
---
To further customize the attention behavior, please refer to the official documentation or source code of the corresponding attention library.
# Feature Caching
xxx
## Cache Acceleration Algorithms
- Cache reuse is an important acceleration technique for diffusion model inference.
- Its core idea is to skip redundant computation at some timesteps and reuse previously cached results to improve inference efficiency.
- The key to the algorithm is deciding at which timesteps to reuse the cache, usually based on dynamic criteria such as changes in the model state or an error threshold.
- During inference, key intermediate results such as features, residuals, and attention outputs are cached. When a reusable timestep is reached, the cached content is used directly and the current output is reconstructed via approximations such as Taylor expansion, reducing repeated computation and enabling efficient inference.
### TeaCache
### TeaCache
The core idea of `TeaCache` is to accumulate the **relative L1** distance between the inputs of adjacent timesteps and compare the running total against a preset threshold to decide whether the current timestep can reuse the cache.
- Concretely, at each inference step the algorithm computes the relative L1 distance between the current input and the previous input and adds it to a running total.
- While the accumulated distance stays below the threshold, the model state is considered to have changed little, so the most recently cached content is reused and the redundant computation is skipped; once the threshold is exceeded, a full forward pass is performed and the accumulator is reset. This significantly reduces the number of forward passes and speeds up inference (see the sketch below).
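A minimal sketch of that decision rule, assuming a simple per-step relative L1 measure; the threshold value and tensor handling are illustrative, not LightX2V's actual implementation:

```python
import torch

class TeaCacheDecider:
    """Accumulate the relative L1 distance between consecutive inputs and
    decide whether the current timestep can reuse the cached output."""

    def __init__(self, threshold: float = 0.08):  # illustrative threshold
        self.threshold = threshold
        self.prev_input = None
        self.accumulated = 0.0

    def should_reuse(self, x: torch.Tensor) -> bool:
        if self.prev_input is None:
            self.prev_input = x.detach()
            return False  # first step: always run the full forward pass
        rel_l1 = (x - self.prev_input).abs().mean() / (self.prev_input.abs().mean() + 1e-8)
        self.accumulated += rel_l1.item()
        self.prev_input = x.detach()
        if self.accumulated < self.threshold:
            return True   # small accumulated change: reuse the cached output
        self.accumulated = 0.0
        return False      # change is large enough: recompute and refresh the cache
```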
In practice, TeaCache achieves a clear speedup while preserving generation quality. A before/after video comparison is shown below:
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 58s | Inference time on a single H200: 17.9s |
| ![Before](../../../../assets/gifs/1.gif) | ![After](../../../../assets/gifs/2.gif) |
- Speedup: **3.24x**
- Config: [wan_t2v_1_3b_tea_480p.json](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json)
- Reference paper: [https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108)
### TaylorSeer Cache
The core of `TaylorSeer Cache` is to use the Taylor formula to recompute cached content, serving as residual compensation at cache-reuse timesteps. Specifically, at a timestep where the cache is reused, the method does not simply replay the historical cache; it also approximately reconstructs the current output via a Taylor expansion. This further improves output accuracy while still reducing computation. The Taylor expansion effectively captures small changes in the model state, compensating for the error introduced by cache reuse, so generation quality is preserved while inference is accelerated. `TaylorSeer Cache` is suitable for scenarios with higher requirements on output precision and further improves inference quality on top of plain cache reuse.
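A minimal sketch of that idea, predicting the feature at a skipped step from finite differences of recently cached features; the expansion order and cache layout are illustrative assumptions:

```python
import torch

def taylor_predict(cached_feats: list, k: int) -> torch.Tensor:
    """Approximate the feature k steps after the last full computation with a
    truncated forward-difference (Taylor-style) expansion over cached features.

    cached_feats: features from the most recent full-computation steps,
    ordered oldest -> newest (illustrative layout).
    """
    pred = cached_feats[-1].clone()
    if len(cached_feats) >= 2:
        d1 = cached_feats[-1] - cached_feats[-2]                         # ~ first derivative
        pred = pred + k * d1
    if len(cached_feats) >= 3:
        d2 = cached_feats[-1] - 2 * cached_feats[-2] + cached_feats[-3]  # ~ second derivative
        pred = pred + (k * (k - 1) / 2.0) * d2
    return pred

# Example: estimate a block's output two skipped steps after the last full pass.
feats = [torch.randn(4, 16) for _ in range(3)]
approx = taylor_predict(feats, k=2)
```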
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 57.7s | Inference time on a single H200: 41.3s |
| ![Before](../../../../assets/gifs/3.gif) | ![After](../../../../assets/gifs/4.gif) |
- Speedup: **1.39x**
- Config: [wan_t2v_taylorseer](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/taylorseer/wan_t2v_taylorseer.json)
- Reference paper: [https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923)
### AdaCache
The core idea of `AdaCache` is to dynamically adjust the cache-reuse step size based on part of the cached content in a designated block.
- The algorithm analyzes the feature difference within a specific block between two adjacent timesteps and adaptively decides the interval until the next cache-reuse timestep based on how large that difference is.
- When the model state changes little, the step size grows automatically, reducing how often the cache is refreshed; when the state changes a lot, the step size shrinks to preserve output quality.
In this way the caching strategy adapts flexibly to the dynamics of the inference process, achieving higher acceleration and better generation results. AdaCache suits applications with high demands on both inference speed and generation quality; a sketch of one possible interval rule is shown below.
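This is an illustration only; the thresholds and step bounds are assumptions, not AdaCache's actual schedule:

```python
import torch

def next_cache_interval(prev_feat: torch.Tensor, curr_feat: torch.Tensor,
                        min_step: int = 1, max_step: int = 6) -> int:
    """Pick how many timesteps to skip before the next full computation,
    based on how much the monitored block's features changed."""
    change = (curr_feat - prev_feat).abs().mean() / (prev_feat.abs().mean() + 1e-8)
    change = change.item()
    if change < 0.02:      # nearly static: skip as many steps as allowed
        return max_step
    if change < 0.05:      # moderate change: medium interval
        return (min_step + max_step) // 2
    return min_step        # large change: recompute at the very next step
```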
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 227s | Inference time on a single H200: 83s |
| ![Before](../../../../assets/gifs/5.gif) | ![After](../../../../assets/gifs/6.gif) |
- Speedup: **2.73x**
- Config: [wan_i2v_ada](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/adacache/wan_i2v_ada.json)
- Reference paper: [https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397)
### CustomCache
`CustomCache` combines the strengths of `TeaCache` and `TaylorSeer Cache`.
- It adopts `TeaCache`'s timely and well-founded cache decisions, using a dynamic threshold to determine when to reuse the cache.
- At the same time, it applies `TaylorSeer`'s Taylor-expansion method to make better use of the cached content.
This not only decides the timing of cache reuse efficiently, but also exploits the cached content to the fullest, improving output accuracy and generation quality. In practice, `CustomCache` produces higher-quality videos than using `TeaCache`, `TaylorSeer Cache`, or `AdaCache` alone across multiple content generation tasks, making it one of the best-performing cache acceleration algorithms overall.
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 57.9s | Inference time on a single H200: 16.6s |
| ![Before](../../../../assets/gifs/7.gif) | ![After](../../../../assets/gifs/8.gif) |
- Speedup: **3.49x**
- Config: [wan_t2v_custom_1_3b](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/custom/wan_t2v_custom_1_3b.json)
## Usage
The config files for feature caching are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching).
By pointing `--config_json` at a specific config file, you can test the different cache algorithms.
Some ready-to-use run scripts are available [here](https://github.com/ModelTC/lightx2v/tree/main/scripts/cache).
......@@ -2,6 +2,10 @@ import torch
try:
import flashinfer
from packaging import version
flashinfer_version = version.parse(flashinfer.__version__)
has_o_dtype = flashinfer_version >= version.parse("0.2.6.post1")
except ImportError:
flashinfer = None
......@@ -29,7 +33,8 @@ def radial_attn(
indptr = get_indptr_from_mask(mask, query)
indices = get_indices_from_mask(mask, query)
bsr_wrapper.plan(
kwargs = dict(
indptr=indptr,
indices=indices,
M=seqlen,
......@@ -43,6 +48,10 @@ def radial_attn(
kv_data_type=key.dtype,
use_fp16_qk_reduction=True,
)
if has_o_dtype:
kwargs["o_data_type"] = query.dtype
bsr_wrapper.plan(**kwargs)
o = bsr_wrapper.run(query, key, value)
......
......@@ -121,8 +121,9 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
except Exception as e:
logger.error(f"Disk worker thread error: {e}")
def _async_prefetch_block(self, weights):
next_block_idx = self.pin_memory_buffer.get_max_block_index()
def _async_prefetch_block(self, blocks, next_block_idx=None):
if next_block_idx is None:
next_block_idx = self.pin_memory_buffer.get_max_block_index()
if next_block_idx < 0:
next_block_idx = 0
......@@ -137,7 +138,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
with self.task_lock:
self.pending_tasks[obj_key] = True
phase = weights.blocks[next_block_idx].compute_phases[phase_idx]
phase = blocks[next_block_idx].compute_phases[phase_idx]
priority_key = (next_block_idx, phase_idx)
self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
......@@ -149,20 +150,20 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
with self.task_lock:
self.pending_tasks[obj_key] = True
block = weights.blocks[next_block_idx]
block = blocks[next_block_idx]
self.disk_task_queue.put((obj_key, (next_block_idx, block)))
def _sync_prefetch_block(self, weights):
def _sync_prefetch_block(self, blocks):
block_idx = 0
while not self.pin_memory_buffer.is_nearly_full():
if self.offload_gra == "phase":
for phase_idx in range(self.phases_num):
phase = weights.blocks[block_idx].compute_phases[phase_idx]
phase = blocks[block_idx].compute_phases[phase_idx]
logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
phase.load_from_disk()
self.pin_memory_buffer.push((block_idx, phase_idx), phase)
else:
block = weights.blocks[block_idx]
block = blocks[block_idx]
logger.info(f"Synchronous loading: block={block_idx}")
for phase in block.compute_phases:
phase.load_from_disk()
......@@ -170,11 +171,11 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
block_idx += 1
def prefetch_weights_from_disk(self, weights):
def prefetch_weights_from_disk(self, blocks):
if self.initial_prefetch_done:
return
self._sync_prefetch_block(weights)
self._sync_prefetch_block(blocks)
self.initial_prefetch_done = True
def prefetch_weights(self, block_idx, blocks):
......@@ -193,7 +194,15 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}")
else:
logger.info("Not find prefetch block={block_idx} task. This is a bug.")
logger.info("Not find prefetch block={block_idx} task.")
logger.info("Sync prefetch block={block_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
for phase_idx in range(self.phases_num):
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 15:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
block = self.pin_memory_buffer.get(obj_key)
......@@ -224,7 +233,14 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
else:
logger.info("Not find prefetch block={block_idx}, phase={phase_idx} task. This is a bug.")
logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
self._async_prefetch_block(blocks, block_idx)
start_time = time.time()
while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
time.sleep(0.001)
if time.time() - start_time > 5:
raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
with torch.cuda.stream(self.cuda_load_stream):
phase = self.pin_memory_buffer.get(obj_key)
......
......@@ -2,14 +2,9 @@ import torch
import torch.nn as nn
from vllm import _custom_ops as ops
try:
import q8_kernels.functional as Q8F
except ImportError:
Q8F = None
class QuantLinearInt8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
......@@ -18,7 +13,7 @@ class QuantLinearInt8(nn.Module):
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
......@@ -44,18 +39,31 @@ class QuantLinearInt8(nn.Module):
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
class QuantLinearFp8(nn.Module):
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
if bias:
self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
else:
self.register_buffer("bias", None)
......@@ -65,7 +73,6 @@ class QuantLinearFp8(nn.Module):
def forward(self, input_tensor):
input_tensor = input_tensor.squeeze(0)
self.weight = self.weight.to(torch.float8_e4m3fn)
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
......@@ -79,4 +86,19 @@ class QuantLinearFp8(nn.Module):
self.weight_scale.float(),
self.bias,
)
return output_tensor.unsqueeze(0)
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def maybe_cast(t):
if t is not None and t.device != fn(t).device:
return fn(t)
return t
self.weight = maybe_cast(self.weight)
self.weight_scale = maybe_cast(self.weight_scale)
self.bias = maybe_cast(self.bias)
return self
......@@ -51,11 +51,11 @@ class GELU(nn.Module):
class T5LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
def __init__(self, dim, eps=1e-6, dtype=torch.float16):
super(T5LayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
......@@ -65,7 +65,7 @@ class T5LayerNorm(nn.Module):
class T5Attention(nn.Module):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
assert dim_attn % num_heads == 0
super(T5Attention, self).__init__()
self.dim = dim
......@@ -82,10 +82,10 @@ class T5Attention(nn.Module):
linear_cls = nn.Linear
# layers
self.q = linear_cls(dim, dim_attn, bias=False)
self.k = linear_cls(dim, dim_attn, bias=False)
self.v = linear_cls(dim, dim_attn, bias=False)
self.o = linear_cls(dim_attn, dim, bias=False)
self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context=None, mask=None, pos_bias=None):
......@@ -125,7 +125,7 @@ class T5Attention(nn.Module):
class T5FeedForward(nn.Module):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
super(T5FeedForward, self).__init__()
self.dim = dim
self.dim_ffn = dim_ffn
......@@ -138,9 +138,9 @@ class T5FeedForward(nn.Module):
else:
linear_cls = nn.Linear
# layers
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False)
self.fc2 = linear_cls(dim_ffn, dim, bias=False)
self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
......@@ -152,7 +152,7 @@ class T5FeedForward(nn.Module):
class T5SelfAttention(nn.Module):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
super(T5SelfAttention, self).__init__()
self.dim = dim
self.dim_attn = dim_attn
......@@ -162,11 +162,11 @@ class T5SelfAttention(nn.Module):
self.shared_pos = shared_pos
# layers
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
self.norm1 = T5LayerNorm(dim, dtype=dtype)
self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
self.norm2 = T5LayerNorm(dim, dtype=dtype)
self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)
def forward(self, x, mask=None, pos_bias=None):
e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
......@@ -212,7 +212,7 @@ class T5CrossAttention(nn.Module):
class T5RelativeEmbedding(nn.Module):
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128):
super(T5RelativeEmbedding, self).__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
......@@ -220,7 +220,7 @@ class T5RelativeEmbedding(nn.Module):
self.max_dist = max_dist
# layers
self.embedding = nn.Embedding(num_buckets, num_heads)
self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype)
def forward(self, lq, lk):
device = self.embedding.weight.device
......@@ -252,7 +252,7 @@ class T5RelativeEmbedding(nn.Module):
class T5Encoder(nn.Module):
def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
def __init__(self, dtype, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
super(T5Encoder, self).__init__()
self.cpu_offload = cpu_offload
......@@ -266,11 +266,11 @@ class T5Encoder(nn.Module):
self.quant_scheme = quant_scheme
# layers
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim)
self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
self.norm = T5LayerNorm(dim, dtype=dtype)
# initialize weights
# self.apply(init_weights)
......@@ -443,10 +443,10 @@ def _t5(
# init model
with torch.device(device):
model = model_cls(**kwargs)
model = model_cls(dtype=dtype, **kwargs)
# set device
model = model.to(dtype=dtype, device=device)
model = model.to(device=device)
return model
......@@ -511,9 +511,10 @@ class T5EncoderModel:
.requires_grad_(False)
)
logger.info(f"Loading weights from {self.checkpoint_path}")
logger.info(f"Start Loading weights from {self.checkpoint_path}")
model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
logger.info(f"End Loading weights from {self.checkpoint_path}")
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
......
......@@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None):
def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
......@@ -69,8 +69,8 @@ class SelfAttention(nn.Module):
else:
linear_cls = nn.Linear
self.to_qkv = linear_cls(dim, dim * 3)
self.proj = linear_cls(dim, dim)
self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype)
self.proj = linear_cls(dim, dim, dtype=dtype)
def forward(self, x):
"""
......@@ -108,7 +108,21 @@ class SwiGLU(nn.Module):
class AttentionBlock(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5, quantized=False, quant_scheme=None):
def __init__(
self,
dim,
mlp_ratio,
num_heads,
post_norm=False,
causal=False,
activation="quick_gelu",
attn_dropout=0.0,
proj_dropout=0.0,
norm_eps=1e-5,
quantized=False,
quant_scheme=None,
dtype=torch.float16,
):
assert activation in ["quick_gelu", "gelu", "swi_glu"]
super().__init__()
self.dim = dim
......@@ -127,13 +141,18 @@ class AttentionBlock(nn.Module):
else:
linear_cls = nn.Linear
self.norm1 = LayerNorm(dim, eps=norm_eps)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme)
self.norm2 = LayerNorm(dim, eps=norm_eps)
self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype)
self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
if activation == "swi_glu":
self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype)
else:
self.mlp = nn.Sequential(linear_cls(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), linear_cls(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
self.mlp = nn.Sequential(
linear_cls(dim, int(dim * mlp_ratio), dtype=dtype),
QuickGELU() if activation == "quick_gelu" else nn.GELU(),
linear_cls(int(dim * mlp_ratio), dim, dtype=dtype),
nn.Dropout(proj_dropout),
)
def forward(self, x):
if self.post_norm:
......@@ -146,7 +165,7 @@ class AttentionBlock(nn.Module):
class AttentionPool(nn.Module):
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
......@@ -159,11 +178,13 @@ class AttentionPool(nn.Module):
# layers
gain = 1.0 / math.sqrt(dim)
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.to_q = nn.Linear(dim, dim)
self.to_kv = nn.Linear(dim, dim * 2)
self.proj = nn.Linear(dim, dim)
self.norm = LayerNorm(dim, eps=norm_eps)
self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
self.to_q = nn.Linear(dim, dim, dtype=dtype)
self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype)
self.proj = nn.Linear(dim, dim, dtype=dtype)
self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), nn.Dropout(proj_dropout)
)
def forward(self, x):
"""
......@@ -191,6 +212,7 @@ class AttentionPool(nn.Module):
class VisionTransformer(nn.Module):
def __init__(
self,
dtype=torch.float16,
image_size=224,
patch_size=16,
dim=768,
......@@ -228,26 +250,26 @@ class VisionTransformer(nn.Module):
# embeddings
gain = 1.0 / math.sqrt(dim)
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype)
if pool_type in ("token", "token_fc"):
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim))
self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype))
self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype))
self.dropout = nn.Dropout(embedding_dropout)
# transformer
self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None
self.transformer = nn.Sequential(
*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme) for _ in range(num_layers)]
*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)]
)
self.post_norm = LayerNorm(dim, eps=norm_eps)
self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
# head
if pool_type == "token":
self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype))
elif pool_type == "token_fc":
self.head = nn.Linear(dim, out_dim)
self.head = nn.Linear(dim, out_dim, dtype=dtype)
elif pool_type == "attn_pool":
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype)
def forward(self, x, interpolation=False, use_31_block=False):
b = x.size(0)
......@@ -276,6 +298,7 @@ class VisionTransformer(nn.Module):
class XLMRobertaCLIP(nn.Module):
def __init__(
self,
dtype=torch.float16,
embed_dim=1024,
image_size=224,
patch_size=14,
......@@ -317,6 +340,7 @@ class XLMRobertaCLIP(nn.Module):
# models
self.visual = VisionTransformer(
dtype=dtype,
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
......@@ -341,12 +365,11 @@ class XLMRobertaCLIP(nn.Module):
def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
model = model_cls(dtype=dtype, **kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
model = model.to(device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
......@@ -395,20 +418,20 @@ class CLIPModel:
else:
self.checkpoint_path = checkpoint_path
logger.info(f"Loading weights from {self.checkpoint_path}")
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
)
self.model = self.model.eval().requires_grad_(False)
weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
keys = list(weight_dict.keys())
for key in keys:
if "textual" in key:
weight_dict.pop(key)
logger.info(f"Start Loading weights from {self.checkpoint_path}")
self.model.load_state_dict(weight_dict)
logger.info(f"End Loading weights from {self.checkpoint_path}")
def visual(self, videos, args):
if args.cpu_offload:
......
import flash_attn
try:
import flash_attn
except ModuleNotFoundError:
flash_attn = None
import math
import torch
import torch.nn as nn
......
......@@ -104,7 +104,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
for block_idx in range(self.blocks_num):
if block_idx == 0:
......@@ -132,7 +132,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if block_idx == self.blocks_num - 1:
self.weights_stream_mgr.pin_memory_buffer.pop_front()
self.weights_stream_mgr._async_prefetch_block(weights)
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
if self.clean_cuda_cache:
del grid_sizes, embed, embed0, seq_lens, freqs, context
......@@ -189,7 +189,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
for block_idx in range(weights.blocks_num):
for phase_idx in range(self.weights_stream_mgr.phases_num):
......@@ -236,7 +236,7 @@ class WanTransformerInfer(BaseTransformerInfer):
self.weights_stream_mgr.swap_phases()
self.weights_stream_mgr._async_prefetch_block(weights)
self.weights_stream_mgr._async_prefetch_block(weights.blocks)
if self.clean_cuda_cache:
del attn_out, y_out, y
......
......@@ -45,6 +45,11 @@ class ApiServer:
self.app.include_router(self.files_router)
self.app.include_router(self.service_router)
def _write_file_sync(self, file_path: Path, content: bytes) -> None:
"""同步写入文件到指定路径"""
with open(file_path, "wb") as buffer:
buffer.write(content)
def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
"""Common file streaming response method"""
assert self.file_service is not None, "File service is not initialized"
......@@ -130,32 +135,30 @@ class ApiServer:
video_duration: int = Form(default=5),
):
"""Create video generation task via form"""
# Process uploaded image file
image_path = ""
assert self.file_service is not None, "File service is not initialized"
if image_file and image_file.filename:
file_extension = Path(image_file.filename).suffix
async def save_file_async(file: UploadFile, target_dir: Path) -> str:
"""异步保存文件到指定目录"""
if not file or not file.filename:
return ""
file_extension = Path(file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
image_path = self.file_service.input_image_dir / unique_filename
file_path = target_dir / unique_filename
with open(image_path, "wb") as buffer:
content = await image_file.read()
buffer.write(content)
content = await file.read()
image_path = str(image_path)
await asyncio.to_thread(self._write_file_sync, file_path, content)
audio_path = ""
if audio_file and audio_file.filename:
file_extension = Path(audio_file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
audio_path = self.file_service.input_audio_dir / unique_filename
return str(file_path)
with open(audio_path, "wb") as buffer:
content = await audio_file.read()
buffer.write(content)
image_path = ""
if image_file and image_file.filename:
image_path = await save_file_async(image_file, self.file_service.input_image_dir)
audio_path = str(audio_path)
audio_path = ""
if audio_file and audio_file.filename:
audio_path = await save_file_async(audio_file, self.file_service.input_audio_dir)
message = TaskRequest(
prompt=prompt,
......@@ -276,6 +279,12 @@ class ApiServer:
"""Get service status"""
return ServiceStatus.get_status_service()
@self.service_router.get("/metadata", response_model=dict)
async def get_service_metadata():
"""Get service metadata"""
assert self.inference_service is not None, "Inference service is not initialized"
return self.inference_service.server_metadata()
def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event):
assert self.video_service is not None, "Video service is not initialized"
try:
......
......@@ -186,6 +186,7 @@ class DistributedInferenceService:
self.is_running = False
def start_distributed_inference(self, args) -> bool:
self.args = args
if self.is_running:
logger.warning("Distributed inference service is already running")
return True
......@@ -311,6 +312,10 @@ class DistributedInferenceService:
return None
def server_metadata(self):
assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
return {"nproc_per_node": self.args.nproc_per_node, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
class VideoGenerationService:
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
......
# Cache
## Cache Acceleration Algorithms
- Cache reuse is an important acceleration technique for diffusion model inference.
- Its core idea is to skip redundant computation at some timesteps and reuse previously cached results to improve inference efficiency.
- The key to the algorithm is deciding at which timesteps to reuse the cache, usually based on dynamic criteria such as changes in the model state or an error threshold.
- During inference, key intermediate results such as features, residuals, and attention outputs are cached. When a reusable timestep is reached, the cached content is used directly and the current output is reconstructed via approximations such as Taylor expansion, reducing repeated computation and enabling efficient inference.
# Feature Caching
## TeaCache
The core idea of `TeaCache` is to accumulate the **relative L1** distance between the inputs of adjacent timesteps and compare the running total against a preset threshold to decide whether the current timestep can reuse the cache.
- Concretely, at each inference step the algorithm computes the relative L1 distance between the current input and the previous input and adds it to a running total.
- While the accumulated distance stays below the threshold, the model state is considered to have changed little, so the most recently cached content is reused and the redundant computation is skipped; once the threshold is exceeded, a full forward pass is performed and the accumulator is reset. This significantly reduces the number of forward passes and speeds up inference.
The config files for feature caching are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching)
In practice, TeaCache achieves a clear speedup while preserving generation quality. A before/after video comparison is shown below:
By specifying --config_json to the specific config file, you can test different cache algorithms.
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 58s | Inference time on a single H200: 17.9s |
| ![Before](../../assets/gifs/1.gif) | ![After](../../assets/gifs/2.gif) |
- Speedup: **3.24x**
- Reference paper: [https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108)
Please refer to our feature caching doc:
## TaylorSeer Cache
The core of `TaylorSeer Cache` is to use the Taylor formula to recompute cached content, serving as residual compensation at cache-reuse timesteps. Specifically, at a timestep where the cache is reused, the method does not simply replay the historical cache; it also approximately reconstructs the current output via a Taylor expansion. This further improves output accuracy while still reducing computation. The Taylor expansion effectively captures small changes in the model state, compensating for the error introduced by cache reuse, so generation quality is preserved while inference is accelerated. `TaylorSeer Cache` is suitable for scenarios with higher requirements on output precision and further improves inference quality on top of plain cache reuse.
[English doc: Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html)
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 57.7s | Inference time on a single H200: 41.3s |
| ![Before](../../assets/gifs/3.gif) | ![After](../../assets/gifs/4.gif) |
- Speedup: **1.39x**
- Reference paper: [https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923)
## AdaCache
The core idea of `AdaCache` is to dynamically adjust the cache-reuse step size based on part of the cached content in a designated block.
- The algorithm analyzes the feature difference within a specific block between two adjacent timesteps and adaptively decides the interval until the next cache-reuse timestep based on how large that difference is.
- When the model state changes little, the step size grows automatically, reducing how often the cache is refreshed; when the state changes a lot, the step size shrinks to preserve output quality.
In this way the caching strategy adapts flexibly to the dynamics of the inference process, achieving higher acceleration and better generation results. AdaCache suits applications with high demands on both inference speed and generation quality.
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 227s | Inference time on a single H200: 83s |
| ![Before](../../assets/gifs/5.gif) | ![After](../../assets/gifs/6.gif) |
- Speedup: **2.73x**
- Reference paper: [https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397)
## CustomCache
`CustomCache` combines the strengths of `TeaCache` and `TaylorSeer Cache`.
- It adopts `TeaCache`'s timely and well-founded cache decisions, using a dynamic threshold to determine when to reuse the cache.
- At the same time, it applies `TaylorSeer`'s Taylor-expansion method to make better use of the cached content.
This not only decides the timing of cache reuse efficiently, but also exploits the cached content to the fullest, improving output accuracy and generation quality. In practice, `CustomCache` produces higher-quality videos than using `TeaCache`, `TaylorSeer Cache`, or `AdaCache` alone across multiple content generation tasks, making it one of the best-performing cache acceleration algorithms overall.
| Before acceleration | After acceleration |
|:------:|:------:|
| Inference time on a single H200: 57.9s | Inference time on a single H200: 16.6s |
| ![Before](../../assets/gifs/7.gif) | ![After](../../assets/gifs/8.gif) |
- Speedup: **3.49x**
[Chinese doc: Feature Caching](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html)