xuwx1 / LightX2V · Commit 04812de2

Refactor Config System (#338)

Unverified commit, authored Sep 29, 2025 by Yang Yong (雍洋) and committed via GitHub on Sep 29, 2025. Parent commit: 6a658f42.
Showing 20 changed files with 157 additions and 109 deletions (+157 −109).
Changed files:

- configs/wan22/wan_moe_i2v_distill.json (+10 −3)
- configs/wan22/wan_moe_i2v_distill_quant.json (+10 −3)
- configs/wan22/wan_moe_t2v.json (+4 −2)
- configs/wan22/wan_moe_t2v_distill.json (+10 −3)
- configs/wan22/wan_ti2v_i2v.json (+5 −2)
- configs/wan22/wan_ti2v_i2v_4090.json (+5 −2)
- configs/wan22/wan_ti2v_t2v.json (+5 −2)
- configs/wan22/wan_ti2v_t2v_4090.json (+5 −2)
- docs/EN/source/method_tutorials/video_frame_interpolation.md (+3 −3)
- docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md (+3 −3)
- lightx2v/deploy/worker/hub.py (+3 −3)
- lightx2v/infer.py (+11 −6)
- lightx2v/models/networks/wan/audio_model.py (+25 −17)
- lightx2v/models/networks/wan/infer/audio/pre_infer.py (+2 −2)
- lightx2v/models/networks/wan/infer/causvid/transformer_infer.py (+2 −2)
- lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py (+32 −32)
- lightx2v/models/networks/wan/infer/pre_infer.py (+1 −1)
- lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py (+5 −5)
- lightx2v/models/networks/wan/infer/transformer_infer.py (+7 −7)
- lightx2v/models/networks/wan/model.py (+9 −9)
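Nearly every Python change in this commit applies one mechanical pattern: attribute-style config reads (config.key) become explicit dict-style reads (config["key"]), while optional keys keep config.get(key, default). A minimal sketch of the two styles and of why the dict form fails more loudly (the AttrConfig class below is illustrative, not LightX2V's real config object):

```python
# Illustrative only: a dict subclass that mimics attribute-style access,
# standing in for the config wrapper being refactored away.
class AttrConfig(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

cfg = AttrConfig(model_cls="wan2.2", infer_steps=4)

print(cfg.model_cls)                  # old style: needs the wrapper
print(cfg["model_cls"])               # new style: works on any mapping
print(cfg.get("cpu_offload", False))  # optional keys keep .get() + default
# cfg["missing"] would raise KeyError: 'missing', naming the absent key,
# whereas attribute access surfaces a less specific AttributeError.
```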
configs/wan22/wan_moe_i2v_distill.json

(In the JSON config diffs below, some paired -/+ lines read identically once each entry is collapsed onto a single line; the underlying change appears to be formatting/reflow only.)

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [3.5, 3.5],
+    "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,5 +19,10 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250]
+    "denoising_step_list": [1000, 750, 500, 250]
 }
```
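For orientation, the updated config can be consumed with nothing more than the standard json module; a small sketch (the loading code is illustrative, while the path and key names mirror the file above):

```python
import json

# Path and key names mirror the config shown above.
with open("configs/wan22/wan_moe_i2v_distill.json") as f:
    cfg = json.load(f)

# Keys this commit touches, read dict-style as the refactor does:
print(cfg["boundary_step_index"])     # 2
print(cfg["denoising_step_list"])     # [1000, 750, 500, 250]
print(cfg["sample_guide_scale"])      # [3.5, 3.5]
print(cfg.get("cpu_offload", False))  # optional key with a default
```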
configs/wan22/wan_moe_i2v_distill_quant.json

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [3.5, 3.5],
+    "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,7 +19,12 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
+    "denoising_step_list": [1000, 750, 500, 250],
     "mm_config": {
         "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
     },
```
configs/wan22/wan_moe_t2v.json

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [4.0, 3.0],
+    "sample_guide_scale": [4.0, 3.0],
     "sample_shift": 12.0,
     "enable_cfg": true,
     "cpu_offload": true,
```
configs/wan22/wan_moe_t2v_distill.json

```diff
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [4.0, 3.0],
+    "sample_guide_scale": [4.0, 3.0],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -16,7 +18,12 @@
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
+    "denoising_step_list": [1000, 750, 500, 250],
     "lora_configs": [
         {
             "name": "low_noise_model",
```
configs/wan22/wan_ti2v_i2v.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
configs/wan22/wan_ti2v_i2v_4090.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
configs/wan22/wan_ti2v_t2v.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
configs/wan22/wan_ti2v_t2v_4090.json

```diff
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
```
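The four wan_ti2v configs above all pin "vae_stride": [4, 16, 16]. The audio_model.py diff below divides pixel dimensions by indices 1 and 2 of this list (height and width strides; reading index 0 as the temporal stride is the usual Wan convention, assumed here). A quick worked check with this file's target size:

```python
# Values copied from the wan_ti2v configs above.
vae_stride = [4, 16, 16]            # (temporal, height, width) assumed order
target_height, target_width = 704, 1280

# Same arithmetic as audio_model.py's start_compile:
#   latents_h = shape[0] // config["vae_stride"][1]
latents_h = target_height // vae_stride[1]   # 704 // 16 == 44
latents_w = target_width // vae_stride[2]    # 1280 // 16 == 80
print(latents_h, latents_w)                  # 44 80
```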
docs/EN/source/method_tutorials/video_frame_interpolation.md

````diff
@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "A beautiful sunset over the ocean" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```
 
 ### Configuration Parameters
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "A cat playing in the garden" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```
 
 ### Higher Frame Rate Enhancement
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "Smooth camera movement" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```
 
 ## Performance Considerations
````
docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md (the Chinese counterpart of the tutorial above; same --save_video_path → --save_result_path rename, with Chinese prompt strings and headings kept verbatim as diff content)

````diff
@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "美丽的海上日落" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```
 
 ### 配置参数说明
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "一只小猫在花园里玩耍" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```
 
 ### 更高帧率增强
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "平滑的相机运动" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```
 
 ## 性能考虑
````
lightx2v/deploy/worker/hub.py

```diff
@@ -23,7 +23,7 @@ from lightx2v.utils.utils import seed_all
 class BaseWorker:
     @ProfilingContext4DebugL1("Init Worker Worker Cost:")
     def __init__(self, args):
-        args.save_video_path = ""
+        args.save_result_path = ""
         config = set_config(args)
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         seed_all(config.seed)
@@ -49,7 +49,7 @@ class BaseWorker:
         self.runner.config["prompt"] = params["prompt"]
         self.runner.config["negative_prompt"] = params.get("negative_prompt", "")
         self.runner.config["image_path"] = params.get("image_path", "")
-        self.runner.config["save_video_path"] = params.get("save_video_path", "")
+        self.runner.config["save_result_path"] = params.get("save_result_path", "")
         self.runner.config["seed"] = params.get("seed", self.fixed_config.get("seed", 42))
         self.runner.config["audio_path"] = params.get("audio_path", "")
@@ -92,7 +92,7 @@ class BaseWorker:
         if stream_video_path is not None:
             tmp_video_path = stream_video_path
-        params["save_video_path"] = tmp_video_path
+        params["save_result_path"] = tmp_video_path
         return tmp_video_path, output_video_path
 
     async def prepare_dit_inputs(self, inputs, data_manager):
```
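One detail worth noting in BaseWorker: per-request values resolve through a layered fallback, request params first, then the worker's fixed config, then a literal default. A standalone sketch of the chain used for "seed" (surrounding classes omitted):

```python
# Layered default resolution, as in hub.py's
#   params.get("seed", self.fixed_config.get("seed", 42))
params = {"prompt": "a cat"}   # request payload without a seed
fixed_config = {"seed": 7}     # worker-level defaults

seed = params.get("seed", fixed_config.get("seed", 42))
print(seed)  # 7 -- request wins if present, else fixed config, else 42
```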
lightx2v/infer.py

```diff
@@ -17,6 +17,7 @@ from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
 from lightx2v.utils.envs import *
+from lightx2v.utils.input_info import set_input_info
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
@@ -24,15 +25,15 @@ from lightx2v.utils.utils import seed_all
 def init_runner(config):
     seed_all(config.seed)
     torch.set_grad_enabled(False)
-    runner = RUNNER_REGISTER[config.model_cls](config)
+    runner = RUNNER_REGISTER[config["model_cls"]](config)
     runner.init_modules()
     return runner
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
     parser.add_argument(
         "--model_cls",
         type=str,
@@ -58,7 +59,7 @@ def main():
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--sf_model_path", type=str, required=False)
     parser.add_argument("--config_json", type=str, required=True)
@@ -91,13 +92,16 @@ def main():
         help="The file of the source mask. Default None.",
     )
-    parser.add_argument("--save_video_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
     parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
     args = parser.parse_args()
 
     seed_all(args.seed)
     # set config
     config = set_config(args)
-    if config.parallel:
+    if config["parallel"]:
         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank())
         set_parallel_config(config)
@@ -106,7 +110,8 @@ def main():
     with ProfilingContext4DebugL1("Total Cost"):
         runner = init_runner(config)
-        runner.run_pipeline()
+        input_info = set_input_info(args)
+        runner.run_pipeline(input_info)
 
     # Clean up distributed process group
     if dist.is_initialized():
```
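Since --save_video_path is renamed outright, command lines using the old flag now fail to parse. Purely as a hedged aside (this is not what the commit does), argparse can keep the old spelling alive as a hidden alias writing to the same destination during a deprecation window:

```python
import argparse

parser = argparse.ArgumentParser()
# New canonical flag introduced by this commit.
parser.add_argument("--save_result_path", type=str, default=None,
                    help="The path to save video path/file")
# Hypothetical compatibility shim, NOT part of the commit: the old flag
# still parses but fills the new destination and stays out of --help.
parser.add_argument("--save_video_path", dest="save_result_path",
                    type=str, help=argparse.SUPPRESS)

args = parser.parse_args(["--save_video_path", "./out.mp4"])
print(args.save_result_path)  # ./out.mp4
```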
lightx2v/models/networks/wan/audio_model.py

```diff
@@ -2,6 +2,7 @@ import os
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from loguru import logger
 
 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
@@ -35,11 +36,11 @@ class WanAudioModel(WanModel):
                 raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
         else:
             adapter_model_name = "audio_adapter_model.safetensors"
-        self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
+        self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name)
         adapter_offload = self.config.get("cpu_offload", False)
         load_from_rank0 = self.config.get("load_from_rank0", False)
-        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
+        self.adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
         if not adapter_offload:
             if not dist.is_initialized() or not load_from_rank0:
                 for key in self.adapter_weights_dict:
@@ -51,17 +52,17 @@
         self.post_infer_class = WanAudioPostInfer
         self.transformer_infer_class = WanAudioTransformerInfer
 
-    def get_graph_name(self, shape, audio_num):
-        return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"
+    def get_graph_name(self, shape, audio_num, with_mask=True):
+        return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"
 
-    def start_compile(self, shape, audio_num):
-        graph_name = self.get_graph_name(shape, audio_num)
+    def start_compile(self, shape, audio_num, with_mask=True):
+        graph_name = self.get_graph_name(shape, audio_num, with_mask)
         logger.info(f"[Compile] Compile shape: {shape}, audio_num: {audio_num}, graph_name: {graph_name}")
         target_video_length = self.config.get("target_video_length", 81)
         latents_length = (target_video_length - 1) // 16 * 4 + 1
-        latents_h = shape[0] // self.config.vae_stride[1]
-        latents_w = shape[1] // self.config.vae_stride[2]
+        latents_h = shape[0] // self.config["vae_stride"][1]
+        latents_w = shape[1] // self.config["vae_stride"][2]
         new_inputs = {}
         new_inputs["text_encoder_output"] = {}
@@ -73,7 +74,11 @@ class WanAudioModel(WanModel):
         new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
         new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
-        new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        if with_mask:
+            new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        else:
+            assert audio_num == 1, "audio_num must be 1 when with_mask is False"
+            new_inputs["person_mask_latens"] = None
         new_inputs["previmg_encoder_output"] = {}
         new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
@@ -90,19 +95,21 @@ class WanAudioModel(WanModel):
         self.enable_compile_mode("_infer_cond_uncond")
 
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
                 self.transformer_weights.non_block_weights_to_cuda()
 
-        max_audio_num_num = self.config.get("compile_max_audios", 1)
+        max_audio_num_num = self.config.get("compile_max_audios", 3)
         for audio_num in range(1, max_audio_num_num + 1):
             for shape in compile_shapes:
-                self.start_compile(shape, audio_num)
+                self.start_compile(shape, audio_num, with_mask=True)
+                if audio_num == 1:
+                    self.start_compile(shape, audio_num, with_mask=False)
 
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -115,9 +122,10 @@ class WanAudioModel(WanModel):
         for shape in compile_shapes:
             assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
 
-    def select_graph_for_compile(self):
-        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}, audio_num: {self.config.get('audio_num')}")
-        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}_{self.config.get('audio_num')}audio")
+    def select_graph_for_compile(self, input_info):
+        logger.info(f"target_h, target_w : {input_info.target_shape[0]}, {input_info.target_shape[1]}, audio_num: {input_info.audio_num}")
+        graph_name = self.get_graph_name(input_info.target_shape, input_info.audio_num, with_mask=input_info.with_mask)
+        self.select_graph("_infer_cond_uncond", graph_name)
         logger.info(f"[Compile] Compile status: {self.get_compile_status()}")
 
     @torch.no_grad()
@@ -138,7 +146,7 @@ class WanAudioModel(WanModel):
         if person_mask_latens is not None:
             pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
 
-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
             padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
             if padding_size > 0:
```
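The net effect of the audio_model.py changes: compiled graphs are now keyed by resolution, audio-track count, and whether a person mask is present, and the compile loop pre-builds a maskless variant for the single-audio case. A standalone reproduction of just the naming and enumeration (the f-string and loop structure are copied from the diff; the shape list is a subset of the allowed shapes):

```python
def get_graph_name(shape, audio_num, with_mask=True):
    # f-string copied from the new get_graph_name in audio_model.py
    return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"

compile_shapes = [[480, 832], [720, 1280]]   # subset of the asserted shapes
max_audios = 3  # new default of config.get("compile_max_audios", 3)

for audio_num in range(1, max_audios + 1):
    for shape in compile_shapes:
        print(get_graph_name(shape, audio_num, with_mask=True))
        if audio_num == 1:
            # maskless graph is only compiled for a single audio track
            print(get_graph_name(shape, audio_num, with_mask=False))
```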
lightx2v/models/networks/wan/infer/audio/pre_infer.py

```diff
@@ -33,7 +33,7 @@ class WanAudioPreInfer(WanPreInfer):
         infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         hidden_states = latents
-        if self.config.model_cls != "wan2.2_audio":
+        if self.config["model_cls"] != "wan2.2_audio":
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
             hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
@@ -101,7 +101,7 @@ class WanAudioPreInfer(WanPreInfer):
             del out
             torch.cuda.empty_cache()
 
-        if self.task == "i2v" and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "s2v"] and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
             if self.clean_cuda_cache:
                 del clip_fea
```
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py

```diff
@@ -140,7 +140,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
     def infer_cross_attn(self, weights, x, context, block_idx):
         norm3_out = weights.norm3.apply(x)
 
-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             context_img = context[:257]
             context = context[257:]
@@ -169,7 +169,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
             model_cls=self.config["model_cls"],
         )
 
-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
```
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py

```diff
@@ -25,22 +25,22 @@ class WanTransformerInferCaching(WanOffloadTransformerInfer):
 class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
         self.accumulated_rel_l1_distance_odd = 0
         self.previous_e0_odd = None
         self.previous_residual_odd = None
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5
-            self.cutoff_steps = self.config.infer_steps
+            self.cutoff_steps = self.config["infer_steps"]
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1
-            self.cutoff_steps = self.config.infer_steps - 1
+            self.cutoff_steps = self.config["infer_steps"] - 1
 
     # calculate should_calc
     @torch.no_grad()
@@ -216,7 +216,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
 
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
 
         return x
@@ -353,7 +353,7 @@ class WanTransformerInferAdaCaching(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(xt)
 
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
 
         return x
@@ -515,7 +515,7 @@ class AdaArgs:
         # Moreg related attributes
         self.previous_moreg = 1.0
         self.moreg_strides = [1]
-        self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
+        self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])]
         self.moreg_hyp = [0.385, 8, 1, 2]
         self.mograd_mul = 10
         self.spatial_dim = 1536
@@ -525,7 +525,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
@@ -534,15 +534,15 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         self.previous_residual_odd = None
         self.cache_even = {}
         self.cache_odd = {}
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
+            self.cutoff_steps = self.config["infer_steps"] * 2
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2
+            self.cutoff_steps = self.config["infer_steps"] * 2 - 2
 
     # 1. get taylor step_diff when there is two caching_records in scheduler
     def get_taylor_step_diff(self):
@@ -625,7 +625,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(x)
 
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         self.cnt += 1
@@ -690,12 +690,12 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_first_block_residual_even = None
         self.prev_remaining_blocks_residual_even = None
         self.prev_first_block_residual_odd = None
         self.prev_remaining_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]
 
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -727,7 +727,7 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(x)
 
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
 
         return x
@@ -795,12 +795,12 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_front_blocks_residual_even = None
         self.prev_middle_blocks_residual_even = None
         self.prev_front_blocks_residual_odd = None
         self.prev_middle_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]
 
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -854,7 +854,7 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
                 context,
             )
 
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
 
         return x
@@ -921,8 +921,8 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
 class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
-        self.downsample_factor = self.config.downsample_factor
+        self.residual_diff_threshold = config["residual_diff_threshold"]
+        self.downsample_factor = self.config["downsample_factor"]
         self.block_in_cache_even = {i: None for i in range(self.blocks_num)}
         self.block_residual_cache_even = {i: None for i in range(self.blocks_num)}
@@ -992,10 +992,10 @@ class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
 class WanTransformerInferMagCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.magcache_thresh = config.magcache_thresh
-        self.K = config.magcache_K
-        self.retention_ratio = config.magcache_retention_ratio
-        self.mag_ratios = np.array(config.magcache_ratios)
+        self.magcache_thresh = config["magcache_thresh"]
+        self.K = config["magcache_K"]
+        self.retention_ratio = config["magcache_retention_ratio"]
+        self.mag_ratios = np.array(config["magcache_ratios"])
         # {True: cond_param, False: uncond_param}
         self.accumulated_err = {True: 0.0, False: 0.0}
         self.accumulated_steps = {True: 0, False: 0}
@@ -1011,10 +1011,10 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         step_index = self.scheduler.step_index
         infer_condition = self.scheduler.infer_condition
 
-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             skip_forward = False
         else:
-            if step_index >= int(self.config.infer_steps * self.retention_ratio):
+            if step_index >= int(self.config["infer_steps"] * self.retention_ratio):
                 # conditional and unconditional in one list
                 cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index]
                 # magnitude ratio between current step and the cached step
@@ -1054,7 +1054,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
             if self.config["cpu_offload"]:
                 previous_residual = previous_residual.cpu()
 
-            if self.config.magcache_calibration and step_index >= 1:
+            if self.config["magcache_calibration"] and step_index >= 1:
                 norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item()
                 norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item()
                 cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item()
@@ -1083,7 +1083,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         self.accumulated_steps = {True: 0, False: 0}
         self.accumulated_ratio = {True: 1.0, False: 1.0}
         self.residual_cache = {True: None, False: None}
 
-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             print("norm ratio")
             print(self.norm_ratio)
             print("norm std")
```
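For readers new to these classes: TeaCache-style caching skips a transformer pass when the timestep-modulated input has barely moved, accumulating a polynomial-rescaled relative L1 distance and recomputing only once it crosses teacache_thresh (always computing during the first ret_steps and past cutoff_steps). A schematic sketch under those assumptions, not the literal method body (which is folded out of the diff above):

```python
import numpy as np

def should_calc(step_index, e0, prev_e0, acc, coefficients,
                teacache_thresh, ret_steps, cutoff_steps):
    """Schematic TeaCache decision: returns (recompute?, new accumulator)."""
    if step_index < ret_steps or step_index >= cutoff_steps:
        return True, 0.0  # warmup / tail steps always recompute
    # Relative L1 change of the modulated input, rescaled by the
    # model-specific polynomial (config["coefficients"]).
    rel_l1 = float(np.abs(e0 - prev_e0).mean() / np.abs(prev_e0).mean())
    acc += np.polyval(coefficients, rel_l1)
    if acc < teacache_thresh:
        return False, acc  # reuse the cached residual
    return True, 0.0       # recompute and reset the accumulator
```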
lightx2v/models/networks/wan/infer/pre_infer.py

```diff
@@ -41,7 +41,7 @@ class WanPreInfer:
         else:
             context = inputs["text_encoder_output"]["context_null"]
 
-        if self.task in ["i2v", "flf2v", "animate"]:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"]:
             if self.config.get("use_image_encoder", True):
                 clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
```
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py

```diff
@@ -39,12 +39,12 @@ class WanSFTransformerInfer(WanTransformerInfer):
         else:
             self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
-        sf_config = self.config.sf_config
-        self.local_attn_size = sf_config.local_attn_size
+        sf_config = self.config["sf_config"]
+        self.local_attn_size = sf_config["local_attn_size"]
         self.max_attention_size = 32760 if self.local_attn_size == -1 else self.local_attn_size * 1560
-        self.num_frame_per_block = sf_config.num_frame_per_block
-        self.num_transformer_blocks = sf_config.num_transformer_blocks
-        self.frame_seq_length = sf_config.frame_seq_length
+        self.num_frame_per_block = sf_config["num_frame_per_block"]
+        self.num_transformer_blocks = sf_config["num_transformer_blocks"]
+        self.frame_seq_length = sf_config["frame_seq_length"]
         self._initialize_kv_cache(self.device, self.dtype)
         self._initialize_crossattn_cache(self.device, self.dtype)
```
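The unchanged context line here fixes the attention budget: 32760 tokens when local_attn_size is -1, otherwise local_attn_size * 1560. Reading 1560 as the per-latent-frame token count (so 32760 = 21 frames) is an inference, not something the diff states; the arithmetic itself is copied verbatim:

```python
# Expression copied from self_forcing/transformer_infer.py; the
# per-frame interpretation in the comments is an assumption.
def max_attention_size(local_attn_size: int) -> int:
    return 32760 if local_attn_size == -1 else local_attn_size * 1560

print(max_attention_size(-1))  # 32760 (full window; 21 * 1560 if 1560/frame)
print(max_attention_size(12))  # 18720 (12-frame local window)
```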
lightx2v/models/networks/wan/infer/transformer_infer.py

```diff
@@ -11,13 +11,13 @@ from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, comp
 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
-        self.task = config.task
+        self.task = config["task"]
         self.attention_type = config.get("attention_type", "flash_attn2")
-        self.blocks_num = config.num_layers
+        self.blocks_num = config["num_layers"]
         self.phases_num = 3
         self.has_post_adapter = False
-        self.num_heads = config.num_heads
-        self.head_dim = config.dim // config.num_heads
+        self.num_heads = config["num_heads"]
+        self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
         if config.get("rotary_chunk", False):
@@ -203,7 +203,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             x.add_(y_out * gate_msa.squeeze())
         norm3_out = phase.norm3.apply(x)
 
-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
         else:
@@ -211,7 +211,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.sensitive_layer_dtype != self.infer_dtype:
             context = context.to(self.infer_dtype)
-            if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+            if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
                 context_img = context_img.to(self.infer_dtype)
 
         n, d = self.num_heads, self.head_dim
@@ -234,7 +234,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             model_cls=self.config["model_cls"],
         )
 
-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True) and context_img is not None:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
             k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
```
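Two constants above are worth unpacking: head_dim is config["dim"] // config["num_heads"], and image-conditioned tasks carry 257 CLIP image tokens prepended to the text context, split off before text cross-attention. A small sketch with illustrative sizes (the 1536/12 pairing and sequence lengths are examples, not a specific Wan checkpoint):

```python
import torch

config = {"dim": 1536, "num_heads": 12, "task": "s2v"}  # illustrative values
head_dim = config["dim"] // config["num_heads"]         # 1536 // 12 == 128

# Image-conditioned tasks: the first 257 tokens are CLIP image features.
context = torch.randn(257 + 512, config["dim"])         # lengths illustrative
if config["task"] in ["i2v", "flf2v", "animate", "s2v"]:
    context_img = context[:257]   # CLIP image tokens
    context = context[257:]      # remaining text tokens
print(head_dim, context_img.shape, context.shape)
```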
lightx2v/models/networks/wan/model.py

```diff
@@ -62,15 +62,15 @@ class WanModel(CompiledMethodsMixin):
         self.init_empty_model = False
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
 
-        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
+        self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
         if self.dit_quantized:
-            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
-            if self.config.model_cls == "wan2.1_distill":
+            dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
+            if self.config["model_cls"] == "wan2.1_distill":
                 dit_quant_scheme = "distill_" + dit_quant_scheme
             if dit_quant_scheme == "gguf":
                 self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
-                self.config.use_gguf = True
+                self.config["use_gguf"] = True
             else:
                 self.dit_quantized_ckpt = find_hf_model_path(config,
@@ -87,7 +87,7 @@ class WanModel(CompiledMethodsMixin):
             self.dit_quantized_ckpt = None
             assert not self.config.get("lazy_load", False)
 
-        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
+        self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
@@ -158,7 +158,7 @@ class WanModel(CompiledMethodsMixin):
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
-                if self.config.adapter_model_path == file_path:
+                if self.config["adapter_model_path"] == file_path:
                     continue
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
@@ -367,7 +367,7 @@ class WanModel(CompiledMethodsMixin):
     @torch.no_grad()
     def infer(self, inputs):
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
@@ -400,7 +400,7 @@ class WanModel(CompiledMethodsMixin):
             self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
 
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -441,7 +441,7 @@ class WanModel(CompiledMethodsMixin):
         pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
 
-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
             padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
```
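Finally, the quantization scheme is recovered positionally from the mm_type string: field 1 of the hyphen-separated name. A worked check using the exact mm_type from wan_moe_i2v_distill_quant.json above:

```python
# mm_type copied verbatim from configs/wan22/wan_moe_i2v_distill_quant.json.
mm_type = "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"

# model.py: dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
dit_quant_scheme = mm_type.split("-")[1]
print(dit_quant_scheme)  # int8

# wan2.1_distill models additionally prefix the scheme:
model_cls = "wan2.1_distill"
if model_cls == "wan2.1_distill":
    dit_quant_scheme = "distill_" + dit_quant_scheme
print(dit_quant_scheme)  # distill_int8
```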