xuwx1 / LightX2V · Commits · ae089db4

Commit ae089db4, authored Jul 11, 2025 by GoatWu

Merge branch 'main' of github.com:ModelTC/lightx2v into dev-debug-distill

Parents: 8b213df0, 4796fc6e

Changes: 50 files in the commit; this page shows 20 changed files with 377 additions and 180 deletions (+377, -180).
Changed files:

- docs/ZH_CN/source/deploy_guides/deploy_comfyui.md (+2, -2)
- docs/ZH_CN/source/deploy_guides/deploy_gradio.md (+1, -1)
- docs/ZH_CN/source/deploy_guides/lora_deploy.md (+3, -0)
- docs/ZH_CN/source/getting_started/benchmark.md (+3, -0)
- docs/ZH_CN/source/getting_started/quickstart.md (+3, -4)
- docs/ZH_CN/source/index.rst (+26, -13)
- docs/ZH_CN/source/method_tutorials/attention.md (+65, -2)
- docs/ZH_CN/source/method_tutorials/autoregressive_distill.md (+3, -0)
- docs/ZH_CN/source/method_tutorials/cache.md (+69, -1)
- docs/ZH_CN/source/method_tutorials/step_distill.md (+3, -0)
- lightx2v/attentions/common/radial_attn.py (+10, -1)
- lightx2v/common/offload/manager.py (+27, -11)
- lightx2v/models/input_encoders/hf/q_linear.py (+33, -11)
- lightx2v/models/input_encoders/hf/t5/model.py (+29, -28)
- lightx2v/models/input_encoders/hf/xlm_roberta/model.py (+54, -31)
- lightx2v/models/networks/wan/audio_adapter.py (+4, -1)
- lightx2v/models/networks/wan/infer/transformer_infer.py (+4, -4)
- lightx2v/server/api.py (+27, -18)
- lightx2v/server/service.py (+5, -0)
- scripts/cache/readme.md (+6, -52)
**docs/ZH_CN/source/deploy_guides/deploy_comfyui.md**

The heading `# comfyui 部署` is capitalized to `# ComfyUI 部署` (ComfyUI Deployment), and the placeholder body ("xxx") is replaced with a note that the feature is coming soon (即将提供该功能).
**docs/ZH_CN/source/deploy_guides/deploy_gradio.md**

The heading `# gradio部署` is fixed to `# Gradio 部署` (Gradio Deployment); the placeholder body is unchanged.
**docs/ZH_CN/source/deploy_guides/lora_deploy.md** (new file)

New stub page `# Lora模型部署` (LoRA Model Deployment) with placeholder content.
**docs/ZH_CN/source/getting_started/benchmark.md** (new file)

New stub page `# 基准测试` (Benchmark) with placeholder content.
**docs/ZH_CN/source/getting_started/quickstart.md**

The standalone transformers reinstall is commented out (the remaining comment notes that only the Hunyuan model needs transformers 4.45.2 and the step can be skipped otherwise), the run script path moves to `scripts/wan/run_wan_t2v.sh`, and the sentence pointing `--config_json` at the config (which also contains required parameters you can modify as needed) now references `wan_t2v.json` instead of the full `${lightx2v_path}/configs/wan_t2v.json` path:

```diff
@@ -25,9 +25,8 @@ git clone https://github.com/ModelTC/lightx2v.git lightx2v && cd lightx2v
 conda create -n lightx2v python=3.11 && conda activate lightx2v
 pip install -r requirements.txt
 
-# 单独重新安装transformers,避免pip的冲突检查
 # 混元模型需要在4.45.2版本的transformers下运行,如果不需要跑混元模型,可以忽略
-pip install transformers==4.45.2
+# pip install transformers==4.45.2
 
 # 安装 flash-attention 2
 git clone https://github.com/Dao-AILab/flash-attention.git --recursive
@@ -41,7 +40,7 @@ cd flash-attention/hopper && python setup.py install
 # 修改脚本中的路径
-bash scripts/run_wan_t2v.sh
+bash scripts/wan/run_wan_t2v.sh
 
-除了脚本中已有的输入参数,`--config_json` 指向的 `${lightx2v_path}/configs/wan_t2v.json` 中也会存在一些必要的参数,可以根据需要,自行修改。
+除了脚本中已有的输入参数,`--config_json` 指向的 `wan_t2v.json` 中也会存在一些必要的参数,可以根据需要,自行修改。
```
**docs/ZH_CN/source/index.rst**

The landing page shrinks the logo from 100% to 80% width, replaces the centered tagline with badge links plus a bolded "LightX2V: 一个轻量级的视频生成推理框架" ("LightX2V: a lightweight video generation inference framework"), and adds a project description (LightX2V is a lightweight video generation inference framework offering a single inference tool built on a range of advanced techniques; it acts as a unified platform supporting text-to-video, T2V, and image-to-video, I2V, generation across different models, where "X2V" means converting an input modality X, such as text or image, to video V) together with GitHub/HuggingFace links. The toctree gains 基准测试 (benchmark), 步数蒸馏 (step distillation), 自回归蒸馏 (autoregressive distillation), and Lora模型部署 (LoRA deployment) entries, the gradio/comfyui entries are capitalized, and the commented-out "Indices and tables" block at the end is removed.

```diff
@@ -2,17 +2,33 @@
 ==================
 
 .. figure:: ../../../assets/img_lightx2v.png
-  :width: 100%
+  :width: 80%
   :align: center
   :alt: Lightx2v
   :class: no-scaled-link
 
 .. raw:: html
 
-    <p style="text-align:center">
-    <strong>一个轻量级的视频生成推理框架</strong>
+    <div align="center" style="font-family: charter;">
+        <a href="https://opensource.org/licenses/Apache-2.0"><img src="https://img.shields.io/badge/License-Apache_2.0-blue.svg" alt="License"></a>
+        <a href="https://deepwiki.com/ModelTC/lightx2v"><img src="https://deepwiki.com/badge.svg" alt="Ask DeepWiki"></a>
+        <a href="https://lightx2v-en.readthedocs.io/en/latest"><img src="https://img.shields.io/badge/docs-English-99cc2" alt="Doc"></a>
+        <a href="https://lightx2v-zhcn.readthedocs.io/zh-cn/latest"><img src="https://img.shields.io/badge/文档-中文-99cc2" alt="Doc"></a>
+        <a href="https://hub.docker.com/r/lightx2v/lightx2v/tags"><img src="https://badgen.net/badge/icon/docker?icon=docker&label" alt="Docker"></a>
+    </div>
+
+    <div align="center" style="font-family: charter;">
+        <strong>LightX2V: 一个轻量级的视频生成推理框架</strong>
+    </div>
+
+LightX2V 是一个轻量级的视频生成推理框架,旨在提供一个利用多种先进的视频生成推理技术的推理工具。该框架作为统一的推理平台,支持不同模型的文本到视频(T2V)和图像到视频(I2V)等生成任务。X2V 表示将不同的输入模态(X,如文本或图像)转换(to)为视频输出(V)。
+
+GitHub: https://github.com/ModelTC/lightx2v
+
+HuggingFace: https://huggingface.co/lightx2v
 
 文档列表
 -------------
@@ -22,6 +38,7 @@
    :caption: 快速入门
 
    快速入门 <getting_started/quickstart.md>
+   基准测试 <getting_started/benchmark.md>
 
 .. toctree::
    :maxdepth: 1
@@ -32,6 +49,8 @@
    注意力机制 <method_tutorials/attention.md>
    参数卸载 <method_tutorials/offload.md>
    并行推理 <method_tutorials/parallel.md>
+   步数蒸馏 <method_tutorials/step_distill.md>
+   自回归蒸馏 <method_tutorials/autoregressive_distill.md>
 
 .. toctree::
    :maxdepth: 1
@@ -39,14 +58,8 @@
    低延迟场景部署 <deploy_guides/for_low_latency.md>
    低资源场景部署 <deploy_guides/for_low_resource.md>
+   Lora模型部署 <deploy_guides/lora_deploy.md>
    服务化部署 <deploy_guides/deploy_service.md>
-   gradio部署 <deploy_guides/deploy_gradio.md>
+   Gradio部署 <deploy_guides/deploy_gradio.md>
-   comfyui部署 <deploy_guides/deploy_comfyui.md>
+   ComfyUI部署 <deploy_guides/deploy_comfyui.md>
    本地windows电脑部署 <deploy_guides/deploy_local_windows.md>
-
-.. Indices and tables
-.. ==================
-.. * :ref:`genindex`
-.. * :ref:`modindex`
```
**docs/ZH_CN/source/method_tutorials/attention.md**

The stub ("# 注意力机制" / "xxx") is replaced with a new tutorial on configuring attention backends; its content, translated:

# 🎯 Attention Type Configuration in the DiT Model

The DiT model in `LightX2V` currently uses attention in three places, and the underlying attention backend can be configured independently for each of them.

---

## Where attention is used

1. **Self-attention over the image tokens**
   - Config key: `self_attn_1_type`
2. **Cross-attention between the image and the text prompt**
   - Config key: `cross_attn_1_type`
3. **Cross-attention between the image and the reference image in I2V mode**
   - Config key: `cross_attn_2_type`

---

## 🚀 Supported attention backends

| Name | Type name | GitHub link |
|--------------------|------------------|-------------|
| Flash Attention 2 | `flash_attn2` | [flash-attention v2](https://github.com/Dao-AILab/flash-attention) |
| Flash Attention 3 | `flash_attn3` | [flash-attention v3](https://github.com/Dao-AILab/flash-attention) |
| Sage Attention 2 | `sage_attn2` | [SageAttention](https://github.com/thu-ml/SageAttention) |
| Radial Attention | `radial_attn` | [Radial Attention](https://github.com/mit-han-lab/radial-attention) |
| Sparge Attention | `sparge_ckpt` | [Sparge Attention](https://github.com/thu-ml/SpargeAttn) |

---

## 🛠️ Configuration example

In the `wan_i2v.json` config file, the attention types can be specified as follows:

```json
{
    "self_attn_1_type": "radial_attn",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3"
}
```

To switch to another backend, simply replace the corresponding value with one of the type names from the table above.

Tip: because of how its sparsity scheme works, `radial_attn` can only be used for self-attention.

---

For Sparge Attention, refer to the `wan_t2v_sparge.json` config file. Sparge Attention requires an additionally trained (tuned) checkpoint:

```json
{
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "sparge": true,
    "sparge_ckpt": "/path/to/sparge_wan2.1_t2v_1.3B.pt"
}
```

---

For further customization of the attention behavior, refer to each attention library's official documentation or implementation.
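As a rough illustration of how such config keys can drive backend selection, here is a minimal, hypothetical dispatch table in Python; it is not the LightX2V implementation, and the backends other than Flash Attention 2 are left as comments because their exact import paths vary:

```python
# Hypothetical sketch: map the config values (self_attn_1_type, cross_attn_1_type,
# cross_attn_2_type) to a callable attention backend. Illustrative only.
from flash_attn import flash_attn_func  # flash_attn2 backend

ATTN_BACKENDS = {
    "flash_attn2": flash_attn_func,
    # "flash_attn3": flash_attn3_func,   # from the hopper build of flash-attention
    # "sage_attn2": sageattn,            # from the SageAttention package
    # "radial_attn": radial_attn,        # self-attention only (sparse mask)
}


def run_attention(attn_type: str, q, k, v):
    """Look up the backend named in the config and apply it to (q, k, v)."""
    return ATTN_BACKENDS[attn_type](q, k, v)


# e.g. with "self_attn_1_type": "flash_attn2" in the JSON config:
# out = run_attention(config["self_attn_1_type"], q, k, v)
```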
**docs/ZH_CN/source/method_tutorials/autoregressive_distill.md** (new file)

New stub page `# 自回归蒸馏` (Autoregressive Distillation) with placeholder content.
**docs/ZH_CN/source/method_tutorials/cache.md**

The stub content ("xxx") under "# 特征缓存" (Feature Caching) is replaced with a full tutorial; its content, translated:

## Cache-based acceleration

- Cache reuse is an important acceleration technique for diffusion-model inference.
- Its core idea is to skip redundant computation at some timesteps and reuse previously cached results to improve inference efficiency.
- The key question is deciding at which timesteps to reuse the cache, which is usually judged dynamically from changes in the model state or from an error threshold.
- During inference, key tensors such as intermediate features, residuals, and attention outputs are cached. When a reusable timestep is reached, the cached content is used directly and the current output is reconstructed with approximations such as a Taylor expansion, reducing repeated computation and enabling efficient inference.

### TeaCache

The core idea of `TeaCache` is to accumulate the **relative L1** distance between the inputs of adjacent timesteps; once the accumulated distance reaches a preset threshold, the current timestep is judged eligible for cache reuse.

- Concretely, at every inference step the algorithm computes the relative L1 distance between the current input and the previous input and adds it to a running total.
- When the accumulated distance exceeds the threshold, meaning the model state has changed enough, the most recently cached content is reused directly and part of the redundant computation is skipped. This significantly reduces the number of forward passes and speeds up inference.

In practice, TeaCache achieves a clear speedup while preserving generation quality. A before/after comparison:

| Before acceleration | After acceleration |
|:------:|:------:|
| Single H200 inference time: 58 s | Single H200 inference time: 17.9 s |
| ![Before](../../../../assets/gifs/1.gif) | ![After](../../../../assets/gifs/2.gif) |

- Speedup: **3.24×**
- Config: [wan_t2v_1_3b_tea_480p.json](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json)
- Reference paper: [https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108)
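A minimal sketch of the accumulate-and-threshold rule described above (a hypothetical helper, not the LightX2V implementation; the threshold value is illustrative):

```python
import torch


class TeaCacheLikeDecider:
    """Accumulate the relative L1 distance between adjacent timestep inputs and
    signal cache reuse once the running total crosses a threshold (illustrative)."""

    def __init__(self, threshold: float = 0.08):
        self.threshold = threshold      # illustrative value; real configs tune this
        self.prev_input = None
        self.accumulated = 0.0

    def should_reuse(self, cur_input: torch.Tensor) -> bool:
        if self.prev_input is None:     # first step: always do the full forward pass
            self.prev_input = cur_input
            return False
        # relative L1 distance between the current and previous timestep inputs
        rel_l1 = (cur_input - self.prev_input).abs().mean() / self.prev_input.abs().mean()
        self.accumulated += rel_l1.item()
        self.prev_input = cur_input
        if self.accumulated >= self.threshold:
            self.accumulated = 0.0      # reset after deciding to reuse the cache
            return True                 # reuse cached outputs at this timestep
        return False
```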
### TaylorSeer Cache

The core of `TaylorSeer Cache` is to recompute the cached content with a Taylor expansion, which serves as residual compensation at cache-reuse timesteps. Concretely, at a reuse timestep the algorithm does not simply reuse the historical cache; it also reconstructs an approximation of the current output via a Taylor expansion. This further improves output accuracy while still reducing computation: the Taylor expansion captures small changes in the model state and compensates for the error introduced by cache reuse, so the speedup does not come at the cost of generation quality.

`TaylorSeer Cache` is suited to scenarios with higher output-accuracy requirements and further improves inference quality on top of plain cache reuse.

| Before acceleration | After acceleration |
|:------:|:------:|
| Single H200 inference time: 57.7 s | Single H200 inference time: 41.3 s |
| ![Before](../../../../assets/gifs/3.gif) | ![After](../../../../assets/gifs/4.gif) |

- Speedup: **1.39×**
- Config: [wan_t2v_taylorseer](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/taylorseer/wan_t2v_taylorseer.json)
- Reference paper: [https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923)
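A first-order sketch of the Taylor-style compensation described above (function and variable names are hypothetical; the actual method can use higher orders and operates on the cached transformer features):

```python
def taylor_reuse(cached_feature, cached_derivative, steps_since_refresh):
    """At a cache-reuse timestep, approximate the current feature as a first-order
    Taylor expansion around the last fully computed step (hypothetical names)."""
    return cached_feature + cached_derivative * steps_since_refresh


def refresh_cache(prev_feature, new_feature, step_gap):
    """After a full forward pass, store the new feature and a finite-difference
    estimate of its derivative, used to compensate later reused steps."""
    derivative = (new_feature - prev_feature) / step_gap
    return new_feature, derivative
```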
### AdaCache

The core idea of `AdaCache` is to adjust the cache-reuse stride dynamically, based on part of the cached content from a designated block.

- The algorithm analyzes the feature difference within that block between two adjacent timesteps and adaptively chooses the interval until the next cache-reuse timestep according to the size of that difference.
- When the model state changes little, the stride grows automatically so the cache is refreshed less often; when the state changes a lot, the stride shrinks to protect output quality.

The caching strategy thus adapts to the dynamics of the actual inference run, yielding both higher speedups and better generation results. AdaCache fits applications that demand both fast inference and high generation quality.

| Before acceleration | After acceleration |
|:------:|:------:|
| Single H200 inference time: 227 s | Single H200 inference time: 83 s |
| ![Before](../../../../assets/gifs/5.gif) | ![After](../../../../assets/gifs/6.gif) |

- Speedup: **2.73×**
- Config: [wan_i2v_ada](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/adacache/wan_i2v_ada.json)
- Reference paper: [https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397)
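A minimal sketch of the adaptive-stride decision (the distance metric, thresholds, and stride values are placeholders, not the configured ones):

```python
def next_reuse_interval(prev_block_feat, cur_block_feat,
                        small_change=0.05, large_change=0.2):
    """Map the feature change inside the monitored block to a cache-reuse stride:
    a small change means the cache can be kept longer, a large change means it
    should be refreshed sooner. Thresholds and strides here are illustrative."""
    diff = (cur_block_feat - prev_block_feat).abs().mean() / prev_block_feat.abs().mean()
    if diff < small_change:
        return 6   # state drifts slowly: skip more steps before the next refresh
    if diff < large_change:
        return 3
    return 1       # state changes quickly: recompute almost every step
```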
### CustomCache

`CustomCache` combines the strengths of `TeaCache` and `TaylorSeer Cache`.

- It adopts `TeaCache`'s timely and well-founded cache decisions, using a dynamic threshold to decide when to reuse the cache.
- At the same time, it applies `TaylorSeer`'s Taylor-expansion method to make fuller use of the cached content.

It therefore decides the reuse timing efficiently while exploiting the cached content as much as possible, improving output accuracy and generation quality. In practice, `CustomCache` produces better video quality than `TeaCache`, `TaylorSeer Cache`, or `AdaCache` alone across multiple generation tasks, making it one of the strongest overall cache-acceleration options currently available.

| Before acceleration | After acceleration |
|:------:|:------:|
| Single H200 inference time: 57.9 s | Single H200 inference time: 16.6 s |
| ![Before](../../../../assets/gifs/7.gif) | ![After](../../../../assets/gifs/8.gif) |

- Speedup: **3.49×**
- Config: [wan_t2v_custom_1_3b](https://github.com/ModelTC/lightx2v/tree/main/configs/caching/custom/wan_t2v_custom_1_3b.json)

## Usage

The config files for feature caching are [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching). By pointing `--config_json` at a specific config file, you can test the different cache algorithms. Some ready-to-use run scripts are available [here](https://github.com/ModelTC/lightx2v/tree/main/scripts/cache).
**docs/ZH_CN/source/method_tutorials/step_distill.md** (new file)

New stub page `# 步数蒸馏` (Step Distillation) with placeholder content.
**lightx2v/attentions/common/radial_attn.py**

```diff
@@ -2,6 +2,10 @@ import torch
 try:
     import flashinfer
+    from packaging import version
+
+    flashinfer_version = version.parse(flashinfer.__version__)
+    has_o_dtype = flashinfer_version >= version.parse("0.2.6.post1")
 except ImportError:
     flashinfer = None
@@ -29,7 +33,8 @@ def radial_attn(
     indptr = get_indptr_from_mask(mask, query)
     indices = get_indices_from_mask(mask, query)
-    bsr_wrapper.plan(
+    kwargs = dict(
         indptr=indptr,
         indices=indices,
         M=seqlen,
@@ -43,6 +48,10 @@ def radial_attn(
         kv_data_type=key.dtype,
         use_fp16_qk_reduction=True,
     )
+    if has_o_dtype:
+        kwargs["o_data_type"] = query.dtype
+
+    bsr_wrapper.plan(**kwargs)
 
     o = bsr_wrapper.run(query, key, value)
```
**lightx2v/common/offload/manager.py**

```diff
@@ -121,7 +121,8 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
         except Exception as e:
             logger.error(f"Disk worker thread error: {e}")
 
-    def _async_prefetch_block(self, weights):
-        next_block_idx = self.pin_memory_buffer.get_max_block_index()
+    def _async_prefetch_block(self, blocks, next_block_idx=None):
+        if next_block_idx is None:
+            next_block_idx = self.pin_memory_buffer.get_max_block_index()
         if next_block_idx < 0:
@@ -137,7 +138,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
             with self.task_lock:
                 self.pending_tasks[obj_key] = True
-            phase = weights.blocks[next_block_idx].compute_phases[phase_idx]
+            phase = blocks[next_block_idx].compute_phases[phase_idx]
             priority_key = (next_block_idx, phase_idx)
             self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
@@ -149,20 +150,20 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
             with self.task_lock:
                 self.pending_tasks[obj_key] = True
-            block = weights.blocks[next_block_idx]
+            block = blocks[next_block_idx]
             self.disk_task_queue.put((obj_key, (next_block_idx, block)))
 
-    def _sync_prefetch_block(self, weights):
+    def _sync_prefetch_block(self, blocks):
         block_idx = 0
         while not self.pin_memory_buffer.is_nearly_full():
             if self.offload_gra == "phase":
                 for phase_idx in range(self.phases_num):
-                    phase = weights.blocks[block_idx].compute_phases[phase_idx]
+                    phase = blocks[block_idx].compute_phases[phase_idx]
                     logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                     phase.load_from_disk()
                     self.pin_memory_buffer.push((block_idx, phase_idx), phase)
             else:
-                block = weights.blocks[block_idx]
+                block = blocks[block_idx]
                 logger.info(f"Synchronous loading: block={block_idx}")
                 for phase in block.compute_phases:
                     phase.load_from_disk()
@@ -170,11 +171,11 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
             block_idx += 1
 
-    def prefetch_weights_from_disk(self, weights):
+    def prefetch_weights_from_disk(self, blocks):
         if self.initial_prefetch_done:
             return
-        self._sync_prefetch_block(weights)
+        self._sync_prefetch_block(blocks)
         self.initial_prefetch_done = True
 
     def prefetch_weights(self, block_idx, blocks):
@@ -193,7 +194,15 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                 if time.time() - start_time > 5:
                     raise TimeoutError(f"Load timeout: block={block_idx}")
         else:
-            logger.info("Not find prefetch block={block_idx} task. This is a bug.")
+            logger.info("Not find prefetch block={block_idx} task.")
+            logger.info("Sync prefetch block={block_idx}.")
+            self._async_prefetch_block(blocks, block_idx)
+            start_time = time.time()
+            for phase_idx in self.phases_num:
+                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
+                    time.sleep(0.001)
+                    if time.time() - start_time > 15:
+                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
 
         with torch.cuda.stream(self.cuda_load_stream):
             block = self.pin_memory_buffer.get(obj_key)
@@ -224,7 +233,14 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                 if time.time() - start_time > 5:
                     raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
         else:
-            logger.info("Not find prefetch block={block_idx}, phase={phase_idx} task. This is a bug.")
+            logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
+            logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
+            self._async_prefetch_block(blocks, block_idx)
+            start_time = time.time()
+            while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
+                time.sleep(0.001)
+                if time.time() - start_time > 5:
+                    raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
 
         with torch.cuda.stream(self.cuda_load_stream):
             phase = self.pin_memory_buffer.get(obj_key)
```
**lightx2v/models/input_encoders/hf/q_linear.py**

```diff
@@ -2,14 +2,9 @@ import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
 
-try:
-    import q8_kernels.functional as Q8F
-except ImportError:
-    Q8F = None
-
 
 class QuantLinearInt8(nn.Module):
-    def __init__(self, in_features, out_features, bias=True):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -18,7 +13,7 @@ class QuantLinearInt8(nn.Module):
         self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
         if bias:
-            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
+            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
         else:
             self.register_buffer("bias", None)
@@ -44,18 +39,31 @@ class QuantLinearInt8(nn.Module):
         )
         return output_tensor.unsqueeze(0)
 
+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
+
 
 class QuantLinearFp8(nn.Module):
-    def __init__(self, in_features, out_features, bias=True):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
         self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
         if bias:
-            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
+            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
         else:
             self.register_buffer("bias", None)
@@ -65,7 +73,6 @@ class QuantLinearFp8(nn.Module):
     def forward(self, input_tensor):
         input_tensor = input_tensor.squeeze(0)
-        self.weight = self.weight.to(torch.float8_e4m3fn)
         shape = (input_tensor.shape[0], self.weight.shape[0])
         dtype = input_tensor.dtype
         device = input_tensor.device
@@ -79,4 +86,19 @@ class QuantLinearFp8(nn.Module):
             self.weight_scale.float(),
             self.bias,
         )
         return output_tensor.unsqueeze(0)
+
+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
```
**lightx2v/models/input_encoders/hf/t5/model.py**

```diff
@@ -51,11 +51,11 @@ class GELU(nn.Module):
 class T5LayerNorm(nn.Module):
-    def __init__(self, dim, eps=1e-6):
+    def __init__(self, dim, eps=1e-6, dtype=torch.float16):
         super(T5LayerNorm, self).__init__()
         self.dim = dim
         self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
+        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))
 
     def forward(self, x):
         x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
@@ -65,7 +65,7 @@ class T5LayerNorm(nn.Module):
 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
+    def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
         assert dim_attn % num_heads == 0
         super(T5Attention, self).__init__()
         self.dim = dim
@@ -82,10 +82,10 @@ class T5Attention(nn.Module):
             linear_cls = nn.Linear
 
         # layers
-        self.q = linear_cls(dim, dim_attn, bias=False)
-        self.k = linear_cls(dim, dim_attn, bias=False)
-        self.v = linear_cls(dim, dim_attn, bias=False)
-        self.o = linear_cls(dim_attn, dim, bias=False)
+        self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, context=None, mask=None, pos_bias=None):
@@ -125,7 +125,7 @@ class T5Attention(nn.Module):
 class T5FeedForward(nn.Module):
-    def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
+    def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
         super(T5FeedForward, self).__init__()
         self.dim = dim
         self.dim_ffn = dim_ffn
@@ -138,9 +138,9 @@ class T5FeedForward(nn.Module):
         else:
             linear_cls = nn.Linear
 
         # layers
-        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
-        self.fc1 = linear_cls(dim, dim_ffn, bias=False)
-        self.fc2 = linear_cls(dim_ffn, dim, bias=False)
+        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
+        self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
+        self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
@@ -152,7 +152,7 @@ class T5FeedForward(nn.Module):
 class T5SelfAttention(nn.Module):
-    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
         super(T5SelfAttention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
@@ -162,11 +162,11 @@ class T5SelfAttention(nn.Module):
         self.shared_pos = shared_pos
 
         # layers
-        self.norm1 = T5LayerNorm(dim)
-        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
-        self.norm2 = T5LayerNorm(dim)
-        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
-        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
+        self.norm1 = T5LayerNorm(dim, dtype=dtype)
+        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
+        self.norm2 = T5LayerNorm(dim, dtype=dtype)
+        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
+        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)
 
     def forward(self, x, mask=None, pos_bias=None):
         e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
@@ -212,7 +212,7 @@ class T5CrossAttention(nn.Module):
 class T5RelativeEmbedding(nn.Module):
-    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+    def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128):
         super(T5RelativeEmbedding, self).__init__()
         self.num_buckets = num_buckets
         self.num_heads = num_heads
@@ -220,7 +220,7 @@ class T5RelativeEmbedding(nn.Module):
         self.max_dist = max_dist
 
         # layers
-        self.embedding = nn.Embedding(num_buckets, num_heads)
+        self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype)
 
     def forward(self, lq, lk):
         device = self.embedding.weight.device
@@ -252,7 +252,7 @@ class T5RelativeEmbedding(nn.Module):
 class T5Encoder(nn.Module):
-    def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
+    def __init__(self, dtype, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
         super(T5Encoder, self).__init__()
         self.cpu_offload = cpu_offload
@@ -266,11 +266,11 @@ class T5Encoder(nn.Module):
         self.quant_scheme = quant_scheme
 
         # layers
-        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
-        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
+        self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
+        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
-        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme) for _ in range(num_layers)])
-        self.norm = T5LayerNorm(dim)
+        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
+        self.norm = T5LayerNorm(dim, dtype=dtype)
 
         # initialize weights
         # self.apply(init_weights)
@@ -443,10 +443,10 @@ def _t5(
     # init model
     with torch.device(device):
-        model = model_cls(**kwargs)
+        model = model_cls(dtype=dtype, **kwargs)
 
     # set device
-    model = model.to(dtype=dtype, device=device)
+    model = model.to(device=device)
     return model
@@ -511,9 +511,10 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        logger.info(f"Loading weights from {self.checkpoint_path}")
+        logger.info(f"Start Loading weights from {self.checkpoint_path}")
         model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
+        logger.info(f"End Loading weights from {self.checkpoint_path}")
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
```
**lightx2v/models/input_encoders/hf/xlm_roberta/model.py**

```diff
@@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
 class SelfAttention(nn.Module):
-    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None):
+    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -69,8 +69,8 @@ class SelfAttention(nn.Module):
         else:
             linear_cls = nn.Linear
 
-        self.to_qkv = linear_cls(dim, dim * 3)
-        self.proj = linear_cls(dim, dim)
+        self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype)
+        self.proj = linear_cls(dim, dim, dtype=dtype)
 
     def forward(self, x):
         """
@@ -108,7 +108,21 @@ class SwiGLU(nn.Module):
 class AttentionBlock(nn.Module):
-    def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5, quantized=False, quant_scheme=None):
+    def __init__(
+        self,
+        dim,
+        mlp_ratio,
+        num_heads,
+        post_norm=False,
+        causal=False,
+        activation="quick_gelu",
+        attn_dropout=0.0,
+        proj_dropout=0.0,
+        norm_eps=1e-5,
+        quantized=False,
+        quant_scheme=None,
+        dtype=torch.float16,
+    ):
         assert activation in ["quick_gelu", "gelu", "swi_glu"]
         super().__init__()
         self.dim = dim
@@ -127,13 +141,18 @@ class AttentionBlock(nn.Module):
         else:
             linear_cls = nn.Linear
 
-        self.norm1 = LayerNorm(dim, eps=norm_eps)
-        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme)
-        self.norm2 = LayerNorm(dim, eps=norm_eps)
+        self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
+        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype)
+        self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
         if activation == "swi_glu":
-            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+            self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype)
         else:
-            self.mlp = nn.Sequential(linear_cls(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), linear_cls(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+            self.mlp = nn.Sequential(
+                linear_cls(dim, int(dim * mlp_ratio), dtype=dtype),
+                QuickGELU() if activation == "quick_gelu" else nn.GELU(),
+                linear_cls(int(dim * mlp_ratio), dim, dtype=dtype),
+                nn.Dropout(proj_dropout),
+            )
 
     def forward(self, x):
         if self.post_norm:
@@ -146,7 +165,7 @@ class AttentionBlock(nn.Module):
 class AttentionPool(nn.Module):
-    def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
+    def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -159,11 +178,13 @@ class AttentionPool(nn.Module):
         # layers
         gain = 1.0 / math.sqrt(dim)
         self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
-        self.to_q = nn.Linear(dim, dim)
-        self.to_kv = nn.Linear(dim, dim * 2)
-        self.proj = nn.Linear(dim, dim)
-        self.norm = LayerNorm(dim, eps=norm_eps)
-        self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+        self.to_q = nn.Linear(dim, dim, dtype=dtype)
+        self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype)
+        self.proj = nn.Linear(dim, dim, dtype=dtype)
+        self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype), nn.Dropout(proj_dropout)
+        )
 
     def forward(self, x):
         """
@@ -191,6 +212,7 @@ class AttentionPool(nn.Module):
 class VisionTransformer(nn.Module):
     def __init__(
         self,
+        dtype=torch.float16,
         image_size=224,
         patch_size=16,
         dim=768,
@@ -228,26 +250,26 @@ class VisionTransformer(nn.Module):
         # embeddings
         gain = 1.0 / math.sqrt(dim)
-        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
+        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype)
         if pool_type in ("token", "token_fc"):
-            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
-        self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim))
+            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype))
+        self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype))
         self.dropout = nn.Dropout(embedding_dropout)
 
         # transformer
-        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+        self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None
         self.transformer = nn.Sequential(
-            *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme) for _ in range(num_layers)]
+            *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)]
         )
-        self.post_norm = LayerNorm(dim, eps=norm_eps)
+        self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
 
         # head
         if pool_type == "token":
-            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+            self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype))
         elif pool_type == "token_fc":
-            self.head = nn.Linear(dim, out_dim)
+            self.head = nn.Linear(dim, out_dim, dtype=dtype)
         elif pool_type == "attn_pool":
-            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
+            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype)
 
     def forward(self, x, interpolation=False, use_31_block=False):
         b = x.size(0)
@@ -276,6 +298,7 @@ class VisionTransformer(nn.Module):
 class XLMRobertaCLIP(nn.Module):
     def __init__(
         self,
+        dtype=torch.float16,
         embed_dim=1024,
         image_size=224,
         patch_size=14,
@@ -317,6 +340,7 @@ class XLMRobertaCLIP(nn.Module):
         # models
         self.visual = VisionTransformer(
+            dtype=dtype,
             image_size=image_size,
             patch_size=patch_size,
             dim=vision_dim,
@@ -341,12 +365,11 @@ class XLMRobertaCLIP(nn.Module):
 def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
     # init a model on device
     with torch.device(device):
-        model = model_cls(**kwargs)
+        model = model_cls(dtype=dtype, **kwargs)
 
-    # set device
-    model = model.to(device=device)
+    model = model.to(dtype=dtype, device=device)
     output = (model,)
 
     # init transforms
     if return_transforms:
         # mean and std
@@ -395,20 +418,20 @@ class CLIPModel:
         else:
             self.checkpoint_path = checkpoint_path
 
-        logger.info(f"Loading weights from {self.checkpoint_path}")
         # init model
         self.model, self.transforms = clip_xlm_roberta_vit_h_14(
             pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device, quantized=self.quantized, quant_scheme=quant_scheme
         )
         self.model = self.model.eval().requires_grad_(False)
         weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
         keys = list(weight_dict.keys())
         for key in keys:
             if "textual" in key:
                 weight_dict.pop(key)
+        logger.info(f"Start Loading weights from {self.checkpoint_path}")
         self.model.load_state_dict(weight_dict)
+        logger.info(f"End Loading weights from {self.checkpoint_path}")
 
     def visual(self, videos, args):
         if args.cpu_offload:
```
**lightx2v/models/networks/wan/audio_adapter.py**

```diff
-import flash_attn
+try:
+    import flash_attn
+except ModuleNotFoundError:
+    flash_attn = None
 import math
 
 import torch
 import torch.nn as nn
```
**lightx2v/models/networks/wan/infer/transformer_infer.py**

```diff
@@ -104,7 +104,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x
 
     def _infer_with_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights)
+        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
@@ -132,7 +132,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             if block_idx == self.blocks_num - 1:
                 self.weights_stream_mgr.pin_memory_buffer.pop_front()
 
-            self.weights_stream_mgr._async_prefetch_block(weights)
+            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
 
         if self.clean_cuda_cache:
             del grid_sizes, embed, embed0, seq_lens, freqs, context
@@ -189,7 +189,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x
 
     def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights)
+        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
         for block_idx in range(weights.blocks_num):
             for phase_idx in range(self.weights_stream_mgr.phases_num):
@@ -236,7 +236,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             self.weights_stream_mgr.swap_phases()
 
-            self.weights_stream_mgr._async_prefetch_block(weights)
+            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
 
         if self.clean_cuda_cache:
             del attn_out, y_out, y
```
**lightx2v/server/api.py**

```diff
@@ -45,6 +45,11 @@ class ApiServer:
         self.app.include_router(self.files_router)
         self.app.include_router(self.service_router)
 
+    def _write_file_sync(self, file_path: Path, content: bytes) -> None:
+        """同步写入文件到指定路径"""
+        with open(file_path, "wb") as buffer:
+            buffer.write(content)
+
     def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
         """Common file streaming response method"""
         assert self.file_service is not None, "File service is not initialized"
@@ -130,32 +135,30 @@ class ApiServer:
             video_duration: int = Form(default=5),
         ):
             """Create video generation task via form"""
-            # Process uploaded image file
-            image_path = ""
             assert self.file_service is not None, "File service is not initialized"
-            if image_file and image_file.filename:
-                file_extension = Path(image_file.filename).suffix
-                unique_filename = f"{uuid.uuid4()}{file_extension}"
-                image_path = self.file_service.input_image_dir / unique_filename
-                with open(image_path, "wb") as buffer:
-                    content = await image_file.read()
-                    buffer.write(content)
-                image_path = str(image_path)
-
-            audio_path = ""
-            if audio_file and audio_file.filename:
-                file_extension = Path(audio_file.filename).suffix
-                unique_filename = f"{uuid.uuid4()}{file_extension}"
-                audio_path = self.file_service.input_audio_dir / unique_filename
-                with open(audio_path, "wb") as buffer:
-                    content = await audio_file.read()
-                    buffer.write(content)
-                audio_path = str(audio_path)
+
+            async def save_file_async(file: UploadFile, target_dir: Path) -> str:
+                """异步保存文件到指定目录"""
+                if not file or not file.filename:
+                    return ""
+                file_extension = Path(file.filename).suffix
+                unique_filename = f"{uuid.uuid4()}{file_extension}"
+                file_path = target_dir / unique_filename
+                content = await file.read()
+                await asyncio.to_thread(self._write_file_sync, file_path, content)
+                return str(file_path)
+
+            image_path = ""
+            if image_file and image_file.filename:
+                image_path = await save_file_async(image_file, self.file_service.input_image_dir)
+
+            audio_path = ""
+            if audio_file and audio_file.filename:
+                audio_path = await save_file_async(audio_file, self.file_service.input_audio_dir)
 
             message = TaskRequest(
                 prompt=prompt,
@@ -276,6 +279,12 @@ class ApiServer:
             """Get service status"""
             return ServiceStatus.get_status_service()
 
+        @self.service_router.get("/metadata", response_model=dict)
+        async def get_service_metadata():
+            """Get service metadata"""
+            assert self.inference_service is not None, "Inference service is not initialized"
+            return self.inference_service.server_metadata()
+
     def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event):
         assert self.video_service is not None, "Video service is not initialized"
         try:
```
**lightx2v/server/service.py**

```diff
@@ -186,6 +186,7 @@ class DistributedInferenceService:
         self.is_running = False
 
     def start_distributed_inference(self, args) -> bool:
+        self.args = args
         if self.is_running:
             logger.warning("Distributed inference service is already running")
             return True
@@ -311,6 +312,10 @@ class DistributedInferenceService:
             return None
 
+    def server_metadata(self):
+        assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
+        return {"nproc_per_node": self.args.nproc_per_node, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
+
 
 class VideoGenerationService:
     def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
```
**scripts/cache/readme.md**

The title changes from `# Cache` to `# Feature Caching`, and the previous Chinese write-up of TeaCache, TaylorSeer Cache, AdaCache, and CustomCache (a duplicate of the feature-caching tutorial above, including the before/after tables and GIF links) is removed. The README now simply points to the docs:

```markdown
# Feature Caching

The config files for feature caching are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching)

By specifying --config_json to the specific config file, you can test different cache algorithms.

Please refer to our feature caching doc:

[English doc: Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html)

[中文文档: 特征缓存](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html)
```