xuwx1 / LightX2V — commit 04812de2 (unverified)

Refactor Config System (#338)

Authored Sep 29, 2025 by Yang Yong (雍洋); committed via GitHub on Sep 29, 2025.
Parent: 6a658f42

Showing 20 changed files with 157 additions and 109 deletions (+157 / -109).
configs/wan22/wan_moe_i2v_distill.json                                   +10  -3
configs/wan22/wan_moe_i2v_distill_quant.json                             +10  -3
configs/wan22/wan_moe_t2v.json                                            +4  -2
configs/wan22/wan_moe_t2v_distill.json                                   +10  -3
configs/wan22/wan_ti2v_i2v.json                                           +5  -2
configs/wan22/wan_ti2v_i2v_4090.json                                      +5  -2
configs/wan22/wan_ti2v_t2v.json                                           +5  -2
configs/wan22/wan_ti2v_t2v_4090.json                                      +5  -2
docs/EN/source/method_tutorials/video_frame_interpolation.md              +3  -3
docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md           +3  -3
lightx2v/deploy/worker/hub.py                                             +3  -3
lightx2v/infer.py                                                        +11  -6
lightx2v/models/networks/wan/audio_model.py                              +25 -17
lightx2v/models/networks/wan/infer/audio/pre_infer.py                     +2  -2
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py           +2  -2
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py  +32 -32
lightx2v/models/networks/wan/infer/pre_infer.py                           +1  -1
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py      +5  -5
lightx2v/models/networks/wan/infer/transformer_infer.py                   +7  -7
lightx2v/models/networks/wan/model.py                                     +9  -9
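Across the Python files below, the refactor is mechanical but consistent: required config keys move from attribute-style access to dict-style indexing, while optional keys keep `.get()` with an explicit default. A minimal sketch of the pattern (values are illustrative, not from the repo):

```python
# Sketch of the access pattern this commit standardizes on (illustrative values).
config = {"model_cls": "wan2.1", "infer_steps": 40}

model_cls = config["model_cls"]                           # required key: raises KeyError if missing
attn_type = config.get("attention_type", "flash_attn2")   # optional key: explicit default
```

Dict indexing makes a missing required key fail immediately at the access site instead of propagating a half-initialized object further into inference.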
configs/wan22/wan_moe_i2v_distill.json

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,5 +19,10 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
     "denoising_step_list": [1000, 750, 500, 250]
 }
```
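As a quick check, the new top-level `seed` key can be read straight from the updated file; a short sketch (run from the repo root):

```python
import json

# Read the refactored config and confirm the new "seed" key (added above).
with open("configs/wan22/wan_moe_i2v_distill.json") as f:
    config = json.load(f)

print(config["seed"])                 # -> 42
print(config["denoising_step_list"])  # -> [1000, 750, 500, 250]
```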
configs/wan22/wan_moe_i2v_distill_quant.json

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,7 +19,12 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
     "denoising_step_list": [1000, 750, 500, 250],
     "mm_config": {
         "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
     },
```
configs/wan22/wan_moe_t2v.json

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": [4.0, 3.0],
     "sample_shift": 12.0,
     "enable_cfg": true,
     "cpu_offload": true,
```
configs/wan22/wan_moe_t2v_distill.json

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": [4.0, 3.0],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -16,7 +18,12 @@
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "boundary_step_index": 2,
     "denoising_step_list": [1000, 750, 500, 250],
     "lora_configs": [
         {
             "name": "low_noise_model",
```
configs/wan22/wan_ti2v_i2v.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
     "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
configs/wan22/wan_ti2v_i2v_4090.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
     "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
configs/wan22/wan_ti2v_t2v.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
     "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
configs/wan22/wan_ti2v_t2v_4090.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
     "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
+    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
docs/EN/source/method_tutorials/video_frame_interpolation.md

````diff
@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "A beautiful sunset over the ocean" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```

 ### Configuration Parameters
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "A cat playing in the garden" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```

 ### Higher Frame Rate Enhancement
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "Smooth camera movement" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```

 ## Performance Considerations
````
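After this commit, callers of `lightx2v/infer.py` must pass `--save_result_path` instead of `--save_video_path`. A sketch of an updated invocation (model and config paths are placeholders from the tutorial):

```python
import subprocess

# Invoke the tutorial command with the renamed flag (placeholder paths).
subprocess.run(
    [
        "python", "lightx2v/infer.py",
        "--model_path", "/path/to/model",
        "--config_json", "./configs/video_frame_interpolation/wan_t2v.json",
        "--prompt", "A beautiful sunset over the ocean",
        "--save_result_path", "./output.mp4",  # was --save_video_path
    ],
    check=True,
)
```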
docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md (Chinese counterpart of the tutorial above, with the same flag rename; the prompts translate to "A beautiful sunset over the sea", "A kitten playing in the garden", and "Smooth camera movement", and the headings to "Configuration Parameters", "Higher Frame Rate Enhancement", and "Performance Considerations")

````diff
@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "美丽的海上日落" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```

 ### 配置参数说明
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "一只小猫在花园里玩耍" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```

 ### 更高帧率增强
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "平滑的相机运动" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```

 ## 性能考虑
````
lightx2v/deploy/worker/hub.py

```diff
@@ -23,7 +23,7 @@ from lightx2v.utils.utils import seed_all
 class BaseWorker:
     @ProfilingContext4DebugL1("Init Worker Worker Cost:")
     def __init__(self, args):
-        args.save_video_path = ""
+        args.save_result_path = ""
         config = set_config(args)
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         seed_all(config.seed)
@@ -49,7 +49,7 @@ class BaseWorker:
         self.runner.config["prompt"] = params["prompt"]
         self.runner.config["negative_prompt"] = params.get("negative_prompt", "")
         self.runner.config["image_path"] = params.get("image_path", "")
-        self.runner.config["save_video_path"] = params.get("save_video_path", "")
+        self.runner.config["save_result_path"] = params.get("save_result_path", "")
         self.runner.config["seed"] = params.get("seed", self.fixed_config.get("seed", 42))
         self.runner.config["audio_path"] = params.get("audio_path", "")
@@ -92,7 +92,7 @@ class BaseWorker:
         if stream_video_path is not None:
             tmp_video_path = stream_video_path
-        params["save_video_path"] = tmp_video_path
+        params["save_result_path"] = tmp_video_path
         return tmp_video_path, output_video_path

     async def prepare_dit_inputs(self, inputs, data_manager):
```
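Deploy clients therefore have to rename the key in their task payloads as well. A hypothetical request dict matching the updated `BaseWorker` (all values are placeholders):

```python
# Hypothetical task params for the updated worker; only "save_result_path"
# is renamed -- the remaining keys are read exactly as in BaseWorker above.
params = {
    "prompt": "A cat playing in the garden",
    "negative_prompt": "",
    "image_path": "",
    "save_result_path": "./output.mp4",  # formerly "save_video_path"
    "seed": 42,
    "audio_path": "",
}
```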
lightx2v/infer.py

```diff
@@ -17,6 +17,7 @@ from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
 from lightx2v.utils.envs import *
+from lightx2v.utils.input_info import set_input_info
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
@@ -24,15 +25,15 @@ from lightx2v.utils.utils import seed_all
 def init_runner(config):
-    seed_all(config.seed)
     torch.set_grad_enabled(False)
-    runner = RUNNER_REGISTER[config.model_cls](config)
+    runner = RUNNER_REGISTER[config["model_cls"]](config)
     runner.init_modules()
     return runner


 def main():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
     parser.add_argument(
         "--model_cls",
         type=str,
@@ -58,7 +59,7 @@ def main():
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--sf_model_path", type=str, required=False)
     parser.add_argument("--config_json", type=str, required=True)
@@ -91,13 +92,16 @@ def main():
         help="The file of the source mask. Default None.",
     )
-    parser.add_argument("--save_video_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
     args = parser.parse_args()
+    seed_all(args.seed)

     # set config
     config = set_config(args)
-    if config.parallel:
+    if config["parallel"]:
         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank())
         set_parallel_config(config)
@@ -106,7 +110,8 @@ def main():
     with ProfilingContext4DebugL1("Total Cost"):
         runner = init_runner(config)
-        runner.run_pipeline()
+        input_info = set_input_info(args)
+        runner.run_pipeline(input_info)

     # Clean up distributed process group
     if dist.is_initialized():
```
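Putting the hunks together, the entrypoint now seeds from the CLI, builds the config, and hands a separate input bundle to the runner. A condensed sketch (it mirrors `main()` above but may omit arguments that `set_config` and `set_input_info` expect):

```python
import argparse

from lightx2v.infer import init_runner
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import set_config
from lightx2v.utils.utils import seed_all

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--model_cls", type=str, default="wan2.1")
parser.add_argument("--task", type=str, default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--save_result_path", type=str, default=None)
args = parser.parse_args()

seed_all(args.seed)                  # seeding moved out of init_runner
config = set_config(args)
runner = init_runner(config)         # looks up RUNNER_REGISTER[config["model_cls"]]
input_info = set_input_info(args)    # new in this commit: per-run inputs travel separately
runner.run_pipeline(input_info)      # run_pipeline now takes the input bundle
```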
lightx2v/models/networks/wan/audio_model.py

```diff
@@ -2,6 +2,7 @@ import os
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from loguru import logger

 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
@@ -35,11 +36,11 @@ class WanAudioModel(WanModel):
                 raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
         else:
             adapter_model_name = "audio_adapter_model.safetensors"
-        self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
+        self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name)
         adapter_offload = self.config.get("cpu_offload", False)
         load_from_rank0 = self.config.get("load_from_rank0", False)
-        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
+        self.adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
         if not adapter_offload:
             if not dist.is_initialized() or not load_from_rank0:
                 for key in self.adapter_weights_dict:
@@ -51,17 +52,17 @@ class WanAudioModel(WanModel):
         self.post_infer_class = WanAudioPostInfer
         self.transformer_infer_class = WanAudioTransformerInfer

-    def get_graph_name(self, shape, audio_num):
-        return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"
+    def get_graph_name(self, shape, audio_num, with_mask=True):
+        return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"

-    def start_compile(self, shape, audio_num):
-        graph_name = self.get_graph_name(shape, audio_num)
+    def start_compile(self, shape, audio_num, with_mask=True):
+        graph_name = self.get_graph_name(shape, audio_num, with_mask)
         logger.info(f"[Compile] Compile shape: {shape}, audio_num: {audio_num}, graph_name: {graph_name}")
         target_video_length = self.config.get("target_video_length", 81)
         latents_length = (target_video_length - 1) // 16 * 4 + 1
-        latents_h = shape[0] // self.config.vae_stride[1]
-        latents_w = shape[1] // self.config.vae_stride[2]
+        latents_h = shape[0] // self.config["vae_stride"][1]
+        latents_w = shape[1] // self.config["vae_stride"][2]
         new_inputs = {}
         new_inputs["text_encoder_output"] = {}
@@ -73,7 +74,11 @@ class WanAudioModel(WanModel):
         new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
         new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
-        new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        if with_mask:
+            new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        else:
+            assert audio_num == 1, "audio_num must be 1 when with_mask is False"
+            new_inputs["person_mask_latens"] = None
         new_inputs["previmg_encoder_output"] = {}
         new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
@@ -90,19 +95,21 @@ class WanAudioModel(WanModel):
         self.enable_compile_mode("_infer_cond_uncond")
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
                 self.transformer_weights.non_block_weights_to_cuda()
-        max_audio_num_num = self.config.get("compile_max_audios", 1)
+        max_audio_num_num = self.config.get("compile_max_audios", 3)
         for audio_num in range(1, max_audio_num_num + 1):
             for shape in compile_shapes:
-                self.start_compile(shape, audio_num)
+                self.start_compile(shape, audio_num, with_mask=True)
+                if audio_num == 1:
+                    self.start_compile(shape, audio_num, with_mask=False)
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -115,9 +122,10 @@ class WanAudioModel(WanModel):
         for shape in compile_shapes:
             assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]

-    def select_graph_for_compile(self):
-        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}, audio_num: {self.config.get('audio_num')}")
-        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}_{self.config.get('audio_num')}audio")
+    def select_graph_for_compile(self, input_info):
+        logger.info(f"target_h, target_w : {input_info.target_shape[0]}, {input_info.target_shape[1]}, audio_num: {input_info.audio_num}")
+        graph_name = self.get_graph_name(input_info.target_shape, input_info.audio_num, with_mask=input_info.with_mask)
+        self.select_graph("_infer_cond_uncond", graph_name)
         logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

     @torch.no_grad()
@@ -138,7 +146,7 @@ class WanAudioModel(WanModel):
         if person_mask_latens is not None:
             pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
             padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
             if padding_size > 0:
```
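The compiled-graph key now encodes the mask flag, so masked and unmasked single-speaker graphs no longer collide. The naming scheme from the diff, restated standalone:

```python
# Graph-name format introduced above (copied from the diff).
def get_graph_name(shape, audio_num, with_mask=True):
    return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"

assert get_graph_name([480, 832], 1) == "graph_480x832_audio_num_1_mask_True"
assert get_graph_name([480, 832], 1, with_mask=False) == "graph_480x832_audio_num_1_mask_False"
```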
lightx2v/models/networks/wan/infer/audio/pre_infer.py

```diff
@@ -33,7 +33,7 @@ class WanAudioPreInfer(WanPreInfer):
         infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         hidden_states = latents
-        if self.config.model_cls != "wan2.2_audio":
+        if self.config["model_cls"] != "wan2.2_audio":
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
             hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
@@ -101,7 +101,7 @@ class WanAudioPreInfer(WanPreInfer):
             del out
             torch.cuda.empty_cache()

-        if self.task == "i2v" and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "s2v"] and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
             if self.clean_cuda_cache:
                 del clip_fea
```
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py

```diff
@@ -140,7 +140,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
     def infer_cross_attn(self, weights, x, context, block_idx):
         norm3_out = weights.norm3.apply(x)
-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             context_img = context[:257]
             context = context[257:]
@@ -169,7 +169,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
             model_cls=self.config["model_cls"],
         )
-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
```
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py

```diff
@@ -25,22 +25,22 @@ class WanTransformerInferCaching(WanOffloadTransformerInfer):
 class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
         self.accumulated_rel_l1_distance_odd = 0
         self.previous_e0_odd = None
         self.previous_residual_odd = None
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5
-            self.cutoff_steps = self.config.infer_steps
+            self.cutoff_steps = self.config["infer_steps"]
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1
-            self.cutoff_steps = self.config.infer_steps - 1
+            self.cutoff_steps = self.config["infer_steps"] - 1

     # calculate should_calc
     @torch.no_grad()
@@ -216,7 +216,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -353,7 +353,7 @@ class WanTransformerInferAdaCaching(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(xt)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -515,7 +515,7 @@ class AdaArgs:
         # Moreg related attributes
         self.previous_moreg = 1.0
         self.moreg_strides = [1]
-        self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
+        self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])]
         self.moreg_hyp = [0.385, 8, 1, 2]
         self.mograd_mul = 10
         self.spatial_dim = 1536
@@ -525,7 +525,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
@@ -534,15 +534,15 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         self.previous_residual_odd = None
         self.cache_even = {}
         self.cache_odd = {}
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
+            self.cutoff_steps = self.config["infer_steps"] * 2
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2
+            self.cutoff_steps = self.config["infer_steps"] * 2 - 2

     # 1. get taylor step_diff when there is two caching_records in scheduler
     def get_taylor_step_diff(self):
@@ -625,7 +625,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(x)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         self.cnt += 1
@@ -690,12 +690,12 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_first_block_residual_even = None
         self.prev_remaining_blocks_residual_even = None
         self.prev_first_block_residual_odd = None
         self.prev_remaining_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -727,7 +727,7 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(x)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -795,12 +795,12 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_front_blocks_residual_even = None
         self.prev_middle_blocks_residual_even = None
         self.prev_front_blocks_residual_odd = None
         self.prev_middle_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -854,7 +854,7 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
             context,
         )
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -921,8 +921,8 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
 class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
-        self.downsample_factor = self.config.downsample_factor
+        self.residual_diff_threshold = config["residual_diff_threshold"]
+        self.downsample_factor = self.config["downsample_factor"]
         self.block_in_cache_even = {i: None for i in range(self.blocks_num)}
         self.block_residual_cache_even = {i: None for i in range(self.blocks_num)}
@@ -992,10 +992,10 @@ class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
 class WanTransformerInferMagCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.magcache_thresh = config.magcache_thresh
-        self.K = config.magcache_K
-        self.retention_ratio = config.magcache_retention_ratio
-        self.mag_ratios = np.array(config.magcache_ratios)
+        self.magcache_thresh = config["magcache_thresh"]
+        self.K = config["magcache_K"]
+        self.retention_ratio = config["magcache_retention_ratio"]
+        self.mag_ratios = np.array(config["magcache_ratios"])
         # {True: cond_param, False: uncond_param}
         self.accumulated_err = {True: 0.0, False: 0.0}
         self.accumulated_steps = {True: 0, False: 0}
@@ -1011,10 +1011,10 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         step_index = self.scheduler.step_index
         infer_condition = self.scheduler.infer_condition
-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             skip_forward = False
         else:
-            if step_index >= int(self.config.infer_steps * self.retention_ratio):
+            if step_index >= int(self.config["infer_steps"] * self.retention_ratio):
                 # conditional and unconditional in one list
                 cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index]
                 # magnitude ratio between current step and the cached step
@@ -1054,7 +1054,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
             if self.config["cpu_offload"]:
                 previous_residual = previous_residual.cpu()
-            if self.config.magcache_calibration and step_index >= 1:
+            if self.config["magcache_calibration"] and step_index >= 1:
                 norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item()
                 norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item()
                 cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item()
@@ -1083,7 +1083,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         self.accumulated_steps = {True: 0, False: 0}
         self.accumulated_ratio = {True: 1.0, False: 1.0}
         self.residual_cache = {True: None, False: None}
-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             print("norm ratio")
             print(self.norm_ratio)
             print("norm std")
```
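All of the caching variants follow the same substitution, so a missing hyperparameter now fails at construction time rather than mid-inference. A small sketch of the access pattern (threshold and coefficient values are made up for illustration):

```python
# Illustrative values only; real runs define these keys in the JSON configs.
config = {"teacache_thresh": 0.26, "use_ret_steps": True,
          "coefficients": [[1.0, 0.0], [0.5, 0.0]], "infer_steps": 40}

thresh = config["teacache_thresh"]                  # KeyError if the key is absent
idx = 0 if config["use_ret_steps"] else 1
coefficients = config["coefficients"][idx]
cutoff_steps = config["infer_steps"]                # was config.infer_steps
```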
lightx2v/models/networks/wan/infer/pre_infer.py

```diff
@@ -41,7 +41,7 @@ class WanPreInfer:
         else:
             context = inputs["text_encoder_output"]["context_null"]

-        if self.task in ["i2v", "flf2v", "animate"]:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"]:
             if self.config.get("use_image_encoder", True):
                 clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
```
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py

```diff
@@ -39,12 +39,12 @@ class WanSFTransformerInfer(WanTransformerInfer):
         else:
             self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
-        sf_config = self.config.sf_config
-        self.local_attn_size = sf_config.local_attn_size
+        sf_config = self.config["sf_config"]
+        self.local_attn_size = sf_config["local_attn_size"]
         self.max_attention_size = 32760 if self.local_attn_size == -1 else self.local_attn_size * 1560
-        self.num_frame_per_block = sf_config.num_frame_per_block
-        self.num_transformer_blocks = sf_config.num_transformer_blocks
-        self.frame_seq_length = sf_config.frame_seq_length
+        self.num_frame_per_block = sf_config["num_frame_per_block"]
+        self.num_transformer_blocks = sf_config["num_transformer_blocks"]
+        self.frame_seq_length = sf_config["frame_seq_length"]
         self._initialize_kv_cache(self.device, self.dtype)
         self._initialize_crossattn_cache(self.device, self.dtype)
```
lightx2v/models/networks/wan/infer/transformer_infer.py

```diff
@@ -11,13 +11,13 @@ from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, comp
 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
-        self.task = config.task
+        self.task = config["task"]
         self.attention_type = config.get("attention_type", "flash_attn2")
-        self.blocks_num = config.num_layers
+        self.blocks_num = config["num_layers"]
         self.phases_num = 3
         self.has_post_adapter = False
-        self.num_heads = config.num_heads
-        self.head_dim = config.dim // config.num_heads
+        self.num_heads = config["num_heads"]
+        self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
         if config.get("rotary_chunk", False):
@@ -203,7 +203,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             x.add_(y_out * gate_msa.squeeze())
         norm3_out = phase.norm3.apply(x)
-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
         else:
@@ -211,7 +211,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.sensitive_layer_dtype != self.infer_dtype:
             context = context.to(self.infer_dtype)
-            if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+            if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
                 context_img = context_img.to(self.infer_dtype)
         n, d = self.num_heads, self.head_dim
@@ -234,7 +234,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             model_cls=self.config["model_cls"],
         )
-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True) and context_img is not None:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
             k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
```
lightx2v/models/networks/wan/model.py

```diff
@@ -62,15 +62,15 @@ class WanModel(CompiledMethodsMixin):
         self.init_empty_model = False
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
-        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
+        self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
         if self.dit_quantized:
-            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
-            if self.config.model_cls == "wan2.1_distill":
+            dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
+            if self.config["model_cls"] == "wan2.1_distill":
                 dit_quant_scheme = "distill_" + dit_quant_scheme
             if dit_quant_scheme == "gguf":
                 self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
-                self.config.use_gguf = True
+                self.config["use_gguf"] = True
             else:
                 self.dit_quantized_ckpt = find_hf_model_path(
                     config,
@@ -87,7 +87,7 @@ class WanModel(CompiledMethodsMixin):
             self.dit_quantized_ckpt = None
             assert not self.config.get("lazy_load", False)
-        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
+        self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
@@ -158,7 +158,7 @@ class WanModel(CompiledMethodsMixin):
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
-                if self.config.adapter_model_path == file_path:
+                if self.config["adapter_model_path"] == file_path:
                     continue
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
@@ -367,7 +367,7 @@ class WanModel(CompiledMethodsMixin):
     @torch.no_grad()
     def infer(self, inputs):
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
@@ -400,7 +400,7 @@ class WanModel(CompiledMethodsMixin):
         self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -441,7 +441,7 @@ class WanModel(CompiledMethodsMixin):
         pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
             padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
```
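The quantization scheme is still derived from the second dash-separated field of `mm_type`; for the quantized config earlier in this commit that resolves to `int8`:

```python
# Parse the quant scheme the same way WanModel does above.
mm_type = "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
dit_quant_scheme = mm_type.split("-")[1]
assert dit_quant_scheme == "int8"
```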