xuwx1 / LightX2V — Commits — e08c4f90

Commit e08c4f90, authored Jul 17, 2025 by sandy, committed by GitHub on Jul 17, 2025

Merge branch 'main' into audio_r2v

Parents: 12bfd120, 6d07a72e
Changes: 191 files in the commit; this page shows 20 changed files with 382 additions and 135 deletions (+382 −135).
Changed files shown on this page:

| File | Additions | Deletions |
|---|---|---|
| docs/ZH_CN/source/method_tutorials/step_distill.md | +117 | −1 |
| lightx2v/api_server.py | +3 | −1 |
| lightx2v/common/apis/dit.py | +0 | −1 |
| lightx2v/common/apis/image_encoder.py | +0 | −1 |
| lightx2v/common/apis/text_encoder.py | +0 | −1 |
| lightx2v/common/apis/vae.py | +0 | −1 |
| lightx2v/common/offload/manager.py | +33 | −11 |
| lightx2v/common/ops/attn/attn_weight.py | +1 | −1 |
| lightx2v/common/ops/conv/conv2d.py | +7 | −0 |
| lightx2v/common/ops/conv/conv3d.py | +7 | −0 |
| lightx2v/common/ops/mm/mm_weight.py | +1 | −1 |
| lightx2v/common/ops/norm/layer_norm_weight.py | +5 | −3 |
| lightx2v/common/ops/norm/rms_norm_weight.py | +5 | −1 |
| lightx2v/common/ops/tensor/tensor.py | +5 | −1 |
| lightx2v/infer.py | +14 | −22 |
| lightx2v/models/input_encoders/hf/llava/model.py | +0 | −9 |
| lightx2v/models/input_encoders/hf/q_linear.py | +33 | −12 |
| lightx2v/models/input_encoders/hf/t5/model.py | +91 | −34 |
| lightx2v/models/input_encoders/hf/xlm_roberta/model.py | +56 | −33 |
| lightx2v/models/networks/wan/audio_adapter.py | +4 | −1 |
docs/ZH_CN/source/method_tutorials/step_distill.md — View file @ e08c4f90

# Step Distillation

Step distillation is a key optimization technique in LightX2V. By training a distilled model, it cuts the number of inference steps from the original 40-50 down to **4 steps**, which greatly speeds up inference while preserving video quality. Alongside step distillation, LightX2V also applies CFG distillation to accelerate inference further.

## 🔍 Technical Principle

Step distillation is implemented with the [Self-Forcing](https://github.com/guandeh17/Self-Forcing) technique. Self-Forcing performs step distillation and CFG distillation on a 1.3B autoregressive model. LightX2V extends it in several ways:

1. **Larger models**: supports step-distillation training for 14B models;
2. **More models**: supports standard bidirectional models, as well as step-distillation training for I2V models.

See [Self-Forcing-Plus](https://github.com/GoatWu/Self-Forcing-Plus) for the concrete implementation.

## 🎯 Features

- **Faster inference**: the step count drops from 40-50 to 4 with no CFG, for a speedup of roughly **20-24x**
- **Quality preserved**: distillation keeps the original video-generation quality
- **Broad compatibility**: supports both T2V and I2V tasks
- **Flexible usage**: load a fully distilled model, or load a step-distillation LoRA on top of the native model

## 🛠️ Configuration Files

### Base configuration files

Several configuration options are provided under [configs/distill/](https://github.com/ModelTC/lightx2v/tree/main/configs/distill):

| Configuration file | Purpose | Model URL |
|----------|------|------------|
| [wan_t2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg.json) | Load the full T2V 4-step distilled model | TODO |
| [wan_i2v_distill_4step_cfg.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg.json) | Load the full I2V 4-step distilled model | TODO |
| [wan_t2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_t2v_distill_4step_cfg_lora.json) | Load the Wan-T2V model plus the step-distillation LoRA | TODO |
| [wan_i2v_distill_4step_cfg_lora.json](https://github.com/ModelTC/lightx2v/blob/main/configs/distill/wan_i2v_distill_4step_cfg_lora.json) | Load the Wan-I2V model plus the step-distillation LoRA | TODO |

### Key configuration parameters

```json
{
  "infer_steps": 4,                              // number of inference steps
  "denoising_step_list": [999, 750, 500, 250],   // denoising timestep list
  "enable_cfg": false,                           // disable CFG for extra speed
  "lora_configs": [                              // LoRA weight paths (optional)
    {
      "path": "path/to/distill_lora.safetensors",
      "strength": 1.0
    }
  ]
}
```
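For intuition, here is a minimal, runnable sketch of what these parameters encode: four denoising timesteps spread over the usual 0-999 diffusion range, with a single (CFG-free) forward pass per step. It only prints the schedule and is not LightX2V code.

```python
import json

# Parse the key fields of a distill config (values copied from above).
cfg = json.loads("""
{
  "infer_steps": 4,
  "denoising_step_list": [999, 750, 500, 250],
  "enable_cfg": false
}
""")

assert len(cfg["denoising_step_list"]) == cfg["infer_steps"]

for i, t in enumerate(cfg["denoising_step_list"], start=1):
    # With enable_cfg == false, each step is one model forward pass instead
    # of the conditional + unconditional pair that CFG would require.
    passes = 2 if cfg["enable_cfg"] else 1
    print(f"step {i}: timestep={t}, forward passes={passes}")
```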
## 📜 Usage

### Model preparation

**Full model:** place the downloaded model (`distill_model.pt` or `distill_model.safetensors`) into the `distill_models/` folder under the Wan model root directory:

- For T2V: `Wan2.1-T2V-14B/distill_models/`
- For I2V-480P: `Wan2.1-I2V-14B-480P/distill_models/`

**LoRA:**

1. Place the downloaded LoRA anywhere you like.
2. Set the `lora_path` parameter in the configuration file to the LoRA's location.

### Inference scripts

**Full T2V model:**

```bash
bash scripts/wan/run_wan_t2v_distill_4step_cfg.sh
```

**Full I2V model:**

```bash
bash scripts/wan/run_wan_i2v_distill_4step_cfg.sh
```

### Step-distillation LoRA inference scripts

**T2V LoRA:**

```bash
bash scripts/wan/run_wan_t2v_distill_4step_cfg_lora.sh
```

**I2V LoRA:**

```bash
bash scripts/wan/run_wan_i2v_distill_4step_cfg_lora.sh
```

## 🔧 Service Deployment

### Starting the distilled-model service

Modify the start command in [scripts/server/start_server.sh](https://github.com/ModelTC/lightx2v/blob/main/scripts/server/start_server.sh):

```bash
python -m lightx2v.api_server \
  --model_cls wan2.1_distill \
  --task t2v \
  --model_path $model_path \
  --config_json ${lightx2v_path}/configs/distill/wan_t2v_distill_4step_cfg.json \
  --port 8000 \
  --nproc_per_node 1
```

Run the service start script:

```bash
scripts/server/start_server.sh
```

See [Service Deployment](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_service.html) for more details.

### Using the Gradio interface

See the [Gradio documentation](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/deploy_guides/deploy_gradio.html).
lightx2v/api_server.py — View file @ e08c4f90

```diff
@@ -36,6 +36,7 @@ def main():
         choices=[
             "wan2.1",
             "hunyuan",
             "wan2.1_distill",
             "wan2.1_causvid",
             "wan2.1_skyreels_v2_df",
             "wan2.1_audio",
@@ -48,6 +49,7 @@ def main():
     parser.add_argument("--split", action="store_true")
     parser.add_argument("--lora_path", type=str, required=False, default=None)
     parser.add_argument("--lora_strength", type=float, default=1.0, help="The strength for the lora (default: 1.0)")
     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
@@ -55,7 +57,7 @@ def main():
     args = parser.parse_args()
     logger.info(f"args: {args}")

-    cache_dir = Path(__file__).parent.parent / ".cache"
+    cache_dir = Path(__file__).parent.parent / "server_cache"

     inference_service = DistributedInferenceService()
     api_server = ApiServer()
```
lightx2v/common/apis/dit.py — View file @ e08c4f90

```diff
@@ -121,7 +121,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = DiTRunner(config)
```
lightx2v/common/apis/image_encoder.py — View file @ e08c4f90

```diff
@@ -116,7 +116,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = ImageEncoderRunner(config)
```
lightx2v/common/apis/text_encoder.py — View file @ e08c4f90

```diff
@@ -119,7 +119,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = TextEncoderRunner(config)
```
lightx2v/common/apis/vae.py — View file @ e08c4f90

```diff
@@ -168,7 +168,6 @@ if __name__ == "__main__":
     with ProfilingContext("Init Server Cost"):
         config = set_config(args)
-        config["mode"] = "split_server"
        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
        runner = VAERunner(config)
```
lightx2v/common/offload/manager.py — View file @ e08c4f90

```diff
@@ -15,6 +15,7 @@ class WeightAsyncStreamManager(object):
         self.cuda_load_stream = torch.cuda.Stream(priority=0)
         self.offload_block_num = int(offload_ratio * blocks_num)
         self.phases_num = phases_num
         self.block_nums = blocks_num
         self.offload_phases_num = blocks_num * phases_num * offload_ratio

     def prefetch_weights(self, block_idx, blocks_weights):
@@ -121,12 +122,16 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
             except Exception as e:
                 logger.error(f"Disk worker thread error: {e}")

-    def _async_prefetch_block(self, weights):
-        next_block_idx = self.pin_memory_buffer.get_max_block_index()
+    def _async_prefetch_block(self, blocks, next_block_idx=None):
+        if next_block_idx is None:
+            next_block_idx = self.pin_memory_buffer.get_max_block_index()
         if next_block_idx < 0:
             next_block_idx = 0
+        if next_block_idx == self.block_nums:
+            return
         if self.offload_gra == "phase":
             for phase_idx in range(self.phases_num):
                 obj_key = (next_block_idx, phase_idx)
@@ -137,7 +142,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                     with self.task_lock:
                         self.pending_tasks[obj_key] = True
-                    phase = weights.blocks[next_block_idx].compute_phases[phase_idx]
+                    phase = blocks[next_block_idx].compute_phases[phase_idx]
                     priority_key = (next_block_idx, phase_idx)
                     self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
@@ -149,32 +154,34 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                 with self.task_lock:
                     self.pending_tasks[obj_key] = True
-                block = weights.blocks[next_block_idx]
+                block = blocks[next_block_idx]
                 self.disk_task_queue.put((obj_key, (next_block_idx, block)))

-    def _sync_prefetch_block(self, weights):
+    def _sync_prefetch_block(self, blocks):
         block_idx = 0
         while not self.pin_memory_buffer.is_nearly_full():
             if self.offload_gra == "phase":
                 for phase_idx in range(self.phases_num):
-                    phase = weights.blocks[block_idx].compute_phases[phase_idx]
+                    phase = blocks[block_idx].compute_phases[phase_idx]
                     logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
                     phase.load_from_disk()
                     self.pin_memory_buffer.push((block_idx, phase_idx), phase)
             else:
-                block = weights.blocks[block_idx]
+                block = blocks[block_idx]
                 logger.info(f"Synchronous loading: block={block_idx}")
                 for phase in block.compute_phases:
                     phase.load_from_disk()
                 self.pin_memory_buffer.push(block_idx, block)
             block_idx += 1
             if block_idx == self.block_nums:
                 break

-    def prefetch_weights_from_disk(self, weights):
+    def prefetch_weights_from_disk(self, blocks):
         if self.initial_prefetch_done:
             return
-        self._sync_prefetch_block(weights)
+        self._sync_prefetch_block(blocks)
         self.initial_prefetch_done = True

     def prefetch_weights(self, block_idx, blocks):
@@ -193,7 +200,15 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                     if time.time() - start_time > 5:
                         raise TimeoutError(f"Load timeout: block={block_idx}")
             else:
-                logger.info("Not find prefetch block={block_idx} task. This is a bug.")
+                logger.info("Not find prefetch block={block_idx} task.")
+                logger.info("Sync prefetch block={block_idx}.")
+                self._async_prefetch_block(blocks, block_idx)
+                start_time = time.time()
+                for phase_idx in self.phases_num:
+                    while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
+                        time.sleep(0.001)
+                        if time.time() - start_time > 15:
+                            raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

         with torch.cuda.stream(self.cuda_load_stream):
             block = self.pin_memory_buffer.get(obj_key)
@@ -224,7 +239,14 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                     if time.time() - start_time > 5:
                         raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
             else:
-                logger.info("Not find prefetch block={block_idx}, phase={phase_idx} task. This is a bug.")
+                logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
+                logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
+                self._async_prefetch_block(blocks, block_idx)
+                start_time = time.time()
+                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
+                    time.sleep(0.001)
+                    if time.time() - start_time > 5:
+                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")

         with torch.cuda.stream(self.cuda_load_stream):
             phase = self.pin_memory_buffer.get(obj_key)
```
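The offload-manager changes widen `_async_prefetch_block` to take the blocks plus an explicit index, and turn the "task not found" case into an on-demand prefetch with a polling timeout instead of treating it as a bug. Below is a minimal, self-contained sketch of the underlying pattern (a disk-worker thread filling a buffer that the consumer polls with a timeout); the classes are toy stand-ins, not the LightX2V implementation.

```python
import queue
import threading
import time


class PrefetchBuffer:
    """Toy stand-in for a pinned-memory buffer keyed by block index."""

    def __init__(self):
        self._data = {}
        self._lock = threading.Lock()

    def push(self, key, value):
        with self._lock:
            self._data[key] = value

    def exists(self, key):
        with self._lock:
            return key in self._data

    def get(self, key):
        with self._lock:
            return self._data[key]


def disk_worker(task_queue, buffer):
    # Simulate loading a block from disk, then publish it to the buffer.
    while True:
        key, payload = task_queue.get()
        time.sleep(0.01)  # stands in for load_from_disk()
        buffer.push(key, payload)
        task_queue.task_done()


def fetch_block(block_idx, task_queue, buffer, timeout=5.0):
    # If the block was never queued asynchronously, queue it now
    # (the "sync prefetch" fallback), then poll until it appears.
    if not buffer.exists(block_idx):
        task_queue.put((block_idx, f"weights-for-block-{block_idx}"))
    start = time.time()
    while not buffer.exists(block_idx):
        time.sleep(0.001)
        if time.time() - start > timeout:
            raise TimeoutError(f"Load timeout: block={block_idx}")
    return buffer.get(block_idx)


if __name__ == "__main__":
    tasks = queue.Queue()
    buf = PrefetchBuffer()
    threading.Thread(target=disk_worker, args=(tasks, buf), daemon=True).start()
    print(fetch_block(3, tasks, buf))
```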
lightx2v/common/ops/attn/attn_weight.py — View file @ e08c4f90

```diff
@@ -23,7 +23,7 @@ except ImportError:
     logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
     flash_attn_varlen_func_v3 = None

-if torch.cuda.get_device_capability(0)[0] <= 8 and torch.cuda.get_device_capability(0)[1] <= 9:
+if torch.cuda.get_device_capability(0) == (8, 9):
     try:
         from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
     except ImportError:
```
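The condition above changed from an element-wise check on the (major, minor) compute-capability tuple to an exact tuple comparison, so the SageAttention int8/fp16 Triton kernel import is only attempted on sm_89 devices. A small illustration of the difference (plain PyTorch, runnable on any CUDA machine):

```python
import torch

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability(0)  # (major, minor), e.g. (8, 9) on RTX 4090-class GPUs
    old_match = cap[0] <= 8 and cap[1] <= 9    # also true for (8, 0), (7, 5), ...
    new_match = cap == (8, 9)                  # true only for sm_89
    print(f"sm_{cap[0]}{cap[1]}: old check={old_match}, new check={new_match}")
```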
lightx2v/common/ops/conv/conv2d.py — View file @ e08c4f90

```diff
@@ -56,3 +56,10 @@ class Conv2dWeight(Conv2dWeightTemplate):
         if self.bias is not None:
             destination[self.bias_name] = self.bias.cpu().detach().clone()
         return destination
+
+    def clear(self):
+        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
+        for attr in attrs:
+            if hasattr(self, attr):
+                delattr(self, attr)
+                setattr(self, attr, None)
```
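This `clear()` shape recurs in the conv3d, norm, and tensor wrappers below: delete the attribute so the tensor it referenced can be freed, then re-set it to `None` so later `is None` checks don't raise `AttributeError`. A toy, self-contained sketch of the idea (not the LightX2V template classes):

```python
import torch


class ToyWeight:
    def __init__(self):
        self.weight = torch.randn(4, 4)
        self.bias = None  # optional attribute may already be None

    def clear(self):
        # Drop tensor references, but keep the attributes readable as None.
        for attr in ["weight", "bias", "pinned_weight", "pinned_bias"]:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)


w = ToyWeight()
w.clear()
assert w.weight is None and w.bias is None  # safe to test, no AttributeError
```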
lightx2v/common/ops/conv/conv3d.py — View file @ e08c4f90

```diff
@@ -66,3 +66,10 @@ class Conv3dWeight(Conv3dWeightTemplate):
         if self.bias is not None:
             destination[self.bias_name] = self.bias.cpu().detach().clone()
         return destination
+
+    def clear(self):
+        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
+        for attr in attrs:
+            if hasattr(self, attr):
+                delattr(self, attr)
+                setattr(self, attr, None)
```
lightx2v/common/ops/mm/mm_weight.py — View file @ e08c4f90

```diff
@@ -145,7 +145,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             self.pinned_weight = self.pinned_weight.t()

     def clear(self):
-        attrs = ["weight", "weight_scale", "bias"]
+        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
```
lightx2v/common/ops/norm/layer_norm_weight.py — View file @ e08c4f90

```diff
@@ -34,9 +34,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
         return self.weight.numel() * self.weight.element_size()

     def clear(self):
-        del self.weight
-        if self.bias is not None:
-            del self.bias
+        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
+        for attr in attrs:
+            if hasattr(self, attr):
+                delattr(self, attr)
+                setattr(self, attr, None)

     @abstractmethod
     def apply(self, input_tensor):
```
lightx2v/common/ops/norm/rms_norm_weight.py — View file @ e08c4f90

```diff
@@ -23,7 +23,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
             self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)

     def clear(self):
-        del self.weight
+        attrs = ["weight", "pinned_weight"]
+        for attr in attrs:
+            if hasattr(self, attr):
+                delattr(self, attr)
+                setattr(self, attr, None)

     @abstractmethod
     def apply(self, input_tensor):
```
lightx2v/common/ops/tensor/tensor.py — View file @ e08c4f90

```diff
@@ -22,7 +22,11 @@ class DefaultTensor:
             self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)

     def clear(self):
-        del self.tensor
+        attrs = ["tensor", "pinned_tensor"]
+        for attr in attrs:
+            if hasattr(self, attr):
+                delattr(self, attr)
+                setattr(self, attr, None)

     def _calculate_size(self):
         return self.tensor.numel() * self.tensor.element_size()
```
lightx2v/infer.py — View file @ e08c4f90

```diff
-import asyncio
 import argparse

 import torch
 import torch.distributed as dist
@@ -40,11 +39,12 @@ def init_runner(config):
     return runner


-async def main():
+def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="wan2.1")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
@@ -52,35 +52,27 @@ async def main():
     parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
     parser.add_argument("--negative_prompt", type=str, default="")
-    parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
-    parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
-    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
-    parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
+    parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
+    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")

     args = parser.parse_args()

-    if args.prompt_path:
-        try:
-            with open(args.prompt_path, "r", encoding="utf-8") as f:
-                args.prompt = f.read().strip()
-            logger.info(f"从文件 {args.prompt_path} 读取到prompt: {args.prompt}")
-        except FileNotFoundError:
-            logger.error(f"找不到prompt文件: {args.prompt_path}")
-            raise
-        except Exception as e:
-            logger.error(f"读取prompt文件时出错: {e}")
-            raise
-
     logger.info(f"args: {args}")

     with ProfilingContext("Total Cost"):
         config = set_config(args)
         config["mode"] = "infer"
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
-        await runner.run_pipeline()
+        runner.run_pipeline()

+    # Clean up distributed process group
+    if dist.is_initialized():
+        dist.destroy_process_group()
+        logger.info("Distributed process group cleaned up")


 if __name__ == "__main__":
-    asyncio.run(main())
+    main()
```
lightx2v/models/input_encoders/hf/llava/model.py — View file @ e08c4f90

```diff
@@ -151,12 +151,3 @@ class TextEncoderHFLlavaModel:
         if config.cpu_offload:
             self.to_cpu()
         return last_hidden_state, attention_mask
-
-
-if __name__ == "__main__":
-    model = TextEncoderHFLlavaModel("/mtc/yongyang/models/x2v_models/hunyuan/lightx2v_format/i2v/text_encoder_i2v", torch.device("cuda"))
-    text = "An Asian man with short hair in black tactical uniform and white clothes waves a firework stick."
-    img_path = "/mtc/yongyang/projects/lightx2v/assets/inputs/imgs/img_1.jpg"
-    img = Image.open(img_path).convert("RGB")
-    outputs = model.infer(text, img, None)
-    logger.info(outputs)
```
lightx2v/models/input_encoders/hf/q_linear.py — View file @ e08c4f90

```diff
@@ -2,14 +2,9 @@ import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops

-try:
-    import q8_kernels.functional as Q8F
-except ImportError:
-    Q8F = None
-

 class QuantLinearInt8(nn.Module):
-    def __init__(self, in_features, out_features, bias=True):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -18,7 +13,7 @@ class QuantLinearInt8(nn.Module):
         self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
         if bias:
-            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
+            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
         else:
             self.register_buffer("bias", None)
@@ -44,18 +39,30 @@ class QuantLinearInt8(nn.Module):
         )
         return output_tensor.unsqueeze(0)

+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
+

 class QuantLinearFp8(nn.Module):
-    def __init__(self, in_features, out_features, bias=True):
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
         self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
         if bias:
-            self.register_buffer("bias", torch.empty(out_features, dtype=torch.float32))
+            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
         else:
             self.register_buffer("bias", None)
@@ -65,7 +72,6 @@ class QuantLinearFp8(nn.Module):
     def forward(self, input_tensor):
         input_tensor = input_tensor.squeeze(0)
-        self.weight = self.weight.to(torch.float8_e4m3fn)
         shape = (input_tensor.shape[0], self.weight.shape[0])
         dtype = input_tensor.dtype
         device = input_tensor.device
@@ -79,4 +85,19 @@ class QuantLinearFp8(nn.Module):
             self.weight_scale.float(),
             self.bias,
         )
         return output_tensor.unsqueeze(0)
+
+    def _apply(self, fn):
+        for module in self.children():
+            module._apply(fn)
+
+        def maybe_cast(t):
+            if t is not None and t.device != fn(t).device:
+                return fn(t)
+            return t
+
+        self.weight = maybe_cast(self.weight)
+        self.weight_scale = maybe_cast(self.weight_scale)
+        self.bias = maybe_cast(self.bias)
+        return self
```
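A note on the `_apply` override added to both quantized linear classes: `nn.Module.to()` routes through `_apply`, whose default behavior runs the conversion function over every floating-point buffer, which would re-cast the fp8 weights and fp32 scales to the requested dtype. Restricting `fn` to device changes keeps `.cuda()` working while leaving buffer dtypes untouched. Below is a reduced, runnable sketch of the same idea using a toy module, not the LightX2V classes.

```python
import torch
import torch.nn as nn


class DeviceOnlyBuffers(nn.Module):
    """Sketch: keep buffer dtypes fixed, but still honor device moves."""

    def __init__(self):
        super().__init__()
        self.register_buffer("weight", torch.zeros(4, 4, dtype=torch.int8))
        self.register_buffer("weight_scale", torch.ones(4, 1, dtype=torch.float32))

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            # Apply fn only if it would move the tensor to another device.
            if t is not None and t.device != fn(t).device:
                return fn(t)
            return t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        return self


m = DeviceOnlyBuffers().to(torch.bfloat16)  # dtype request is ignored for buffers
assert m.weight.dtype == torch.int8 and m.weight_scale.dtype == torch.float32
```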
lightx2v/models/input_encoders/hf/t5/model.py — View file @ e08c4f90

```diff
@@ -27,6 +27,14 @@ def fp16_clamp(x):
     return x


+def optimize_memory_usage():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    import gc
+
+    gc.collect()
+
+
 def init_weights(m):
     if isinstance(m, T5LayerNorm):
         nn.init.ones_(m.weight)
@@ -51,11 +59,11 @@ class GELU(nn.Module):
 class T5LayerNorm(nn.Module):
-    def __init__(self, dim, eps=1e-6):
+    def __init__(self, dim, eps=1e-6, dtype=torch.float16):
         super(T5LayerNorm, self).__init__()
         self.dim = dim
         self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
+        self.weight = nn.Parameter(torch.ones(dim, dtype=dtype))

     def forward(self, x):
         x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
@@ -65,7 +73,7 @@ class T5LayerNorm(nn.Module):
 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None):
+    def __init__(self, dim, dim_attn, num_heads, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
         assert dim_attn % num_heads == 0
         super(T5Attention, self).__init__()
         self.dim = dim
@@ -82,10 +90,10 @@ class T5Attention(nn.Module):
             linear_cls = nn.Linear

         # layers
-        self.q = linear_cls(dim, dim_attn, bias=False)
-        self.k = linear_cls(dim, dim_attn, bias=False)
-        self.v = linear_cls(dim, dim_attn, bias=False)
-        self.o = linear_cls(dim_attn, dim, bias=False)
+        self.q = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.k = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.v = linear_cls(dim, dim_attn, bias=False, dtype=dtype)
+        self.o = linear_cls(dim_attn, dim, bias=False, dtype=dtype)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x, context=None, mask=None, pos_bias=None):
@@ -114,10 +122,14 @@ class T5Attention(nn.Module):
         # compute attention (T5 does not use scaling)
         attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            del attn_bias
         attn = F.softmax(attn.float(), dim=-1).to(torch.bfloat16)
         x = torch.einsum("bnij,bjnc->binc", attn, v)

         # output
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            del attn
         x = x.reshape(b, -1, n * c)
         x = self.o(x)
         x = self.dropout(x)
@@ -125,7 +137,7 @@ class T5Attention(nn.Module):
 class T5FeedForward(nn.Module):
-    def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None):
+    def __init__(self, dim, dim_ffn, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
         super(T5FeedForward, self).__init__()
         self.dim = dim
         self.dim_ffn = dim_ffn
@@ -138,13 +150,20 @@ class T5FeedForward(nn.Module):
         else:
             linear_cls = nn.Linear

         # layers
-        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False), GELU())
-        self.fc1 = linear_cls(dim, dim_ffn, bias=False)
-        self.fc2 = linear_cls(dim_ffn, dim, bias=False)
+        self.gate = nn.Sequential(linear_cls(dim, dim_ffn, bias=False, dtype=dtype), GELU())
+        self.fc1 = linear_cls(dim, dim_ffn, bias=False, dtype=dtype)
+        self.fc2 = linear_cls(dim_ffn, dim, bias=False, dtype=dtype)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
-        x = self.fc1(x) * self.gate(x)
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            gate_out = self.gate(x)
+            fc1_out = self.fc1(x)
+            x = fc1_out * gate_out
+            del gate_out, fc1_out
+        else:
+            x = self.fc1(x) * self.gate(x)
         x = self.dropout(x)
         x = self.fc2(x)
         x = self.dropout(x)
@@ -152,7 +171,7 @@ class T5FeedForward(nn.Module):
 class T5SelfAttention(nn.Module):
-    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1, quantized=False, quant_scheme=None, dtype=torch.bfloat16):
         super(T5SelfAttention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
@@ -162,16 +181,27 @@ class T5SelfAttention(nn.Module):
         self.shared_pos = shared_pos

         # layers
-        self.norm1 = T5LayerNorm(dim)
-        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme)
-        self.norm2 = T5LayerNorm(dim)
-        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme)
-        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
+        self.norm1 = T5LayerNorm(dim, dtype=dtype)
+        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, quantized, quant_scheme, dtype)
+        self.norm2 = T5LayerNorm(dim, dtype=dtype)
+        self.ffn = T5FeedForward(dim, dim_ffn, dropout, quantized, quant_scheme, dtype=dtype)
+        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype)

     def forward(self, x, mask=None, pos_bias=None):
         e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
-        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
-        x = fp16_clamp(x + self.ffn(self.norm2(x)))
+        if hasattr(self, "cpu_offload") and self.cpu_offload:
+            attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=e)
+            x = fp16_clamp(x + attn_out)
+            del attn_out
+            ffn_out = self.ffn(self.norm2(x))
+            x = fp16_clamp(x + ffn_out)
+            del ffn_out
+        else:
+            x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
+            x = fp16_clamp(x + self.ffn(self.norm2(x)))
         return x
@@ -212,7 +242,7 @@ class T5CrossAttention(nn.Module):
 class T5RelativeEmbedding(nn.Module):
-    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+    def __init__(self, num_buckets, num_heads, bidirectional, dtype=torch.bfloat16, max_dist=128):
         super(T5RelativeEmbedding, self).__init__()
         self.num_buckets = num_buckets
         self.num_heads = num_heads
@@ -220,7 +250,7 @@ class T5RelativeEmbedding(nn.Module):
         self.max_dist = max_dist

         # layers
-        self.embedding = nn.Embedding(num_buckets, num_heads)
+        self.embedding = nn.Embedding(num_buckets, num_heads, dtype=dtype)

     def forward(self, lq, lk):
         device = self.embedding.weight.device
@@ -252,7 +282,7 @@ class T5RelativeEmbedding(nn.Module):
 class T5Encoder(nn.Module):
-    def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
+    def __init__(self, dtype, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1, cpu_offload=False, quantized=False, quant_scheme=None):
         super(T5Encoder, self).__init__()
         self.cpu_offload = cpu_offload
@@ -266,11 +296,17 @@ class T5Encoder(nn.Module):
         self.quant_scheme = quant_scheme

         # layers
-        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
-        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
+        self.token_embedding = vocab.to(dtype) if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, dtype=dtype)
+        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, dtype=dtype) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
-        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme) for _ in range(num_layers)])
-        self.norm = T5LayerNorm(dim)
+        self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, quantized, quant_scheme, dtype) for _ in range(num_layers)])
+        if cpu_offload:
+            for block in self.blocks:
+                block.cpu_offload = cpu_offload
+                block.attn.cpu_offload = cpu_offload
+                block.ffn.cpu_offload = cpu_offload
+        self.norm = T5LayerNorm(dim, dtype=dtype)

         # initialize weights
         # self.apply(init_weights)
@@ -281,23 +317,32 @@
         x = self.token_embedding(ids)
         if self.cpu_offload:
             self.token_embedding = self.token_embedding.cpu()
+            optimize_memory_usage()
         x = self.dropout(x)
+        if self.cpu_offload and self.pos_embedding is not None:
+            self.pos_embedding = self.pos_embedding.cuda()
         e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
+        if self.cpu_offload and self.pos_embedding is not None:
+            self.pos_embedding = self.pos_embedding.cpu()
+            optimize_memory_usage()
-        for block in self.blocks:
+        for i, block in enumerate(self.blocks):
             if self.cpu_offload:
                 block = block.cuda()
             x = block(x, mask, pos_bias=e)
             if self.cpu_offload:
                 block = block.cpu()
+                del block
+                optimize_memory_usage()
         if self.cpu_offload:
             self.norm = self.norm.cuda()
         x = self.norm(x)
         if self.cpu_offload:
             self.norm = self.norm.cpu()
+            optimize_memory_usage()
         x = self.dropout(x)
         return x.to(torch.bfloat16)
@@ -443,10 +488,10 @@ def _t5(
     # init model
     with torch.device(device):
-        model = model_cls(**kwargs)
+        model = model_cls(dtype=dtype, **kwargs)

     # set device
-    model = model.to(dtype=dtype, device=device)
+    model = model.to(device=device)
     return model
@@ -511,9 +556,10 @@ class T5EncoderModel:
             .requires_grad_(False)
         )
-        logger.info(f"Loading weights from {self.checkpoint_path}")
+        logger.info(f"Start Loading weights from {self.checkpoint_path}")
         model.load_state_dict(torch.load(self.checkpoint_path, map_location="cpu", weights_only=True))
+        logger.info(f"End Loading weights from {self.checkpoint_path}")
         self.model = model
         if shard_fn is not None:
             self.model = shard_fn(self.model, sync_module_states=False)
@@ -528,6 +574,10 @@ class T5EncoderModel:
     def to_cuda(self):
         self.model = self.model.to("cuda")

+    def optimize_memory(self):
+        """优化内存使用"""
+        optimize_memory_usage()
+
     def infer(self, texts):
         if self.cpu_offload and self.offload_granularity == "model":
             self.to_cuda()
@@ -536,10 +586,17 @@ class T5EncoderModel:
         ids = ids.cuda()
         mask = mask.cuda()
         seq_lens = mask.gt(0).sum(dim=1).long()
-        context = self.model(ids, mask)
+        with torch.no_grad():
+            context = self.model(ids, mask)

         if self.cpu_offload and self.offload_granularity == "model":
             self.to_cpu()
+            optimize_memory_usage()
+
+        del ids, mask
+        if self.cpu_offload:
+            optimize_memory_usage()

         return [u[:v] for u, v in zip(context, seq_lens)]
```
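The T5 changes thread a `dtype` through every submodule constructor and make the CPU-offload path release memory more aggressively (per-block `cuda()`/`cpu()` moves plus `empty_cache`/`gc.collect` between blocks, and `torch.no_grad()` around the encoder call). Below is a reduced, runnable sketch of that block-by-block offload loop using a generic stack of `nn.Linear` blocks rather than the T5 encoder:

```python
import gc
import torch
import torch.nn as nn


def optimize_memory_usage():
    # Mirror of the helper added in the diff: free cached CUDA blocks and
    # collect garbage so the just-offloaded block's memory is reusable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def run_with_block_offload(blocks: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    # Keep only one block on the accelerator at a time: move it in, run it,
    # move it back out, then release the cache before the next block.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.no_grad():
        x = x.to(device)
        for block in blocks:
            block = block.to(device)
            x = block(x)
            block = block.to("cpu")
            optimize_memory_usage()
    return x


blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])  # toy stand-in
out = run_with_block_offload(blocks, torch.randn(2, 64))
print(out.shape)
```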
lightx2v/models/input_encoders/hf/xlm_roberta/model.py — View file @ e08c4f90

```diff
@@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
 class SelfAttention(nn.Module):
-    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None):
+    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0, quantized=False, quant_scheme=None, dtype=None):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -69,8 +69,8 @@ class SelfAttention(nn.Module):
         else:
             linear_cls = nn.Linear

-        self.to_qkv = linear_cls(dim, dim * 3)
-        self.proj = linear_cls(dim, dim)
+        self.to_qkv = linear_cls(dim, dim * 3, dtype=dtype)
+        self.proj = linear_cls(dim, dim, dtype=dtype)

     def forward(self, x):
         """
@@ -108,7 +108,21 @@ class SwiGLU(nn.Module):
 class AttentionBlock(nn.Module):
-    def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5, quantized=False, quant_scheme=None):
+    def __init__(
+        self,
+        dim,
+        mlp_ratio,
+        num_heads,
+        post_norm=False,
+        causal=False,
+        activation="quick_gelu",
+        attn_dropout=0.0,
+        proj_dropout=0.0,
+        norm_eps=1e-5,
+        quantized=False,
+        quant_scheme=None,
+        dtype=torch.float16,
+    ):
         assert activation in ["quick_gelu", "gelu", "swi_glu"]
         super().__init__()
         self.dim = dim
@@ -127,13 +141,18 @@ class AttentionBlock(nn.Module):
         else:
             linear_cls = nn.Linear

-        self.norm1 = LayerNorm(dim, eps=norm_eps)
-        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme)
-        self.norm2 = LayerNorm(dim, eps=norm_eps)
+        self.norm1 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
+        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout, quantized, quant_scheme, dtype)
+        self.norm2 = LayerNorm(dim, eps=norm_eps, dtype=dtype)
         if activation == "swi_glu":
-            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+            self.mlp = SwiGLU(dim, int(dim * mlp_ratio), dtype=dtype)
         else:
-            self.mlp = nn.Sequential(linear_cls(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), linear_cls(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+            self.mlp = nn.Sequential(
+                linear_cls(dim, int(dim * mlp_ratio), dtype=dtype),
+                QuickGELU() if activation == "quick_gelu" else nn.GELU(),
+                linear_cls(int(dim * mlp_ratio), dim, dtype=dtype),
+                nn.Dropout(proj_dropout),
+            )

     def forward(self, x):
         if self.post_norm:
@@ -146,7 +165,7 @@ class AttentionBlock(nn.Module):
 class AttentionPool(nn.Module):
-    def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
+    def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5, dtype=torch.float16):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -159,11 +178,13 @@ class AttentionPool(nn.Module):
         # layers
         gain = 1.0 / math.sqrt(dim)
         self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
-        self.to_q = nn.Linear(dim, dim)
-        self.to_kv = nn.Linear(dim, dim * 2)
-        self.proj = nn.Linear(dim, dim)
-        self.norm = LayerNorm(dim, eps=norm_eps)
-        self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+        self.to_q = nn.Linear(dim, dim, dtype=dtype)
+        self.to_kv = nn.Linear(dim, dim * 2, dtype=dtype)
+        self.proj = nn.Linear(dim, dim, dtype=dtype)
+        self.norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, int(dim * mlp_ratio), dtype=dtype),
+            QuickGELU() if activation == "quick_gelu" else nn.GELU(),
+            nn.Linear(int(dim * mlp_ratio), dim, dtype=dtype),
+            nn.Dropout(proj_dropout)
+        )

     def forward(self, x):
         """
@@ -191,6 +212,7 @@ class AttentionPool(nn.Module):
 class VisionTransformer(nn.Module):
     def __init__(
         self,
+        dtype=torch.float16,
         image_size=224,
         patch_size=16,
         dim=768,
@@ -228,26 +250,26 @@ class VisionTransformer(nn.Module):
         # embeddings
         gain = 1.0 / math.sqrt(dim)
-        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
+        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm, dtype=dtype)
         if pool_type in ("token", "token_fc"):
-            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
-        self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim))
+            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim, dtype=dtype))
+        self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim, dtype=dtype))
         self.dropout = nn.Dropout(embedding_dropout)

         # transformer
-        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+        self.pre_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype) if pre_norm else None
         self.transformer = nn.Sequential(
-            *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme) for _ in range(num_layers)]
+            *[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps, quantized, quant_scheme, dtype) for _ in range(num_layers)]
         )
-        self.post_norm = LayerNorm(dim, eps=norm_eps)
+        self.post_norm = LayerNorm(dim, eps=norm_eps, dtype=dtype)

         # head
         if pool_type == "token":
-            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+            self.head = nn.Parameter(gain * torch.randn(dim, out_dim, dtype=dtype))
         elif pool_type == "token_fc":
-            self.head = nn.Linear(dim, out_dim)
+            self.head = nn.Linear(dim, out_dim, dtype=dtype)
         elif pool_type == "attn_pool":
-            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)
+            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps, dtype=dtype)

     def forward(self, x, interpolation=False, use_31_block=False):
         b = x.size(0)
@@ -276,6 +298,7 @@ class VisionTransformer(nn.Module):
 class XLMRobertaCLIP(nn.Module):
     def __init__(
         self,
+        dtype=torch.float16,
         embed_dim=1024,
         image_size=224,
         patch_size=14,
@@ -317,6 +340,7 @@ class XLMRobertaCLIP(nn.Module):
         # models
         self.visual = VisionTransformer(
+            dtype=dtype,
             image_size=image_size,
             patch_size=patch_size,
             dim=vision_dim,
@@ -341,12 +365,11 @@ def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
     # init a model on device
     with torch.device(device):
-        model = model_cls(**kwargs)
+        model = model_cls(dtype=dtype, **kwargs)

     # set device
-    model = model.to(dtype=dtype, device=device)
-    output = (model,)
+    model = model.to(device=device)
+    output = (model,)

     # init transforms
     if return_transforms:
         # mean and std
@@ -395,23 +418,23 @@ class CLIPModel:
         else:
             self.checkpoint_path = checkpoint_path

-        logger.info(f"Loading weights from {self.checkpoint_path}")
         # init model
         self.model, self.transforms = clip_xlm_roberta_vit_h_14(
             pretrained=False,
             return_transforms=True,
             return_tokenizer=False,
             dtype=dtype,
             device=device,
             quantized=self.quantized,
             quant_scheme=quant_scheme,
         )
         self.model = self.model.eval().requires_grad_(False)
         weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
         keys = list(weight_dict.keys())
         for key in keys:
             if "textual" in key:
                 weight_dict.pop(key)
+        logger.info(f"Start Loading weights from {self.checkpoint_path}")
         self.model.load_state_dict(weight_dict)
+        logger.info(f"End Loading weights from {self.checkpoint_path}")

     def visual(self, videos, args):
-        if args.cpu_offload:
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cuda()
         # preprocess
         size = (self.model.image_size,) * 2
@@ -422,7 +445,7 @@ class CLIPModel:
         with torch.amp.autocast("cuda", dtype=self.dtype):
             out = self.model.visual(videos, use_31_block=True)
-        if args.cpu_offload:
+        if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cpu()
         return out
```
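Most of the xlm_roberta changes pass `dtype=` straight into the `nn.Linear` / `nn.LayerNorm` / `nn.Conv2d` constructors instead of building the modules in fp32 and converting afterwards with `.to(dtype)`. The practical difference is transient memory: converting after the fact briefly holds an fp32 copy of every weight. A tiny illustration:

```python
import torch
import torch.nn as nn

# Two ways to end up with an fp16 linear layer. Building in fp32 and
# converting afterwards briefly allocates the fp32 weights first;
# passing dtype= allocates them in fp16 from the start.
a = nn.Linear(4096, 4096).to(torch.float16)      # fp32 alloc, then cast
b = nn.Linear(4096, 4096, dtype=torch.float16)   # fp16 from the start

print(a.weight.dtype, b.weight.dtype)  # torch.float16 torch.float16
```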
lightx2v/models/networks/wan/audio_adapter.py — View file @ e08c4f90

```diff
-import flash_attn
+try:
+    import flash_attn
+except ModuleNotFoundError:
+    flash_attn = None
 import math

 import torch
 import torch.nn as nn
```