xuwx1 / LightX2V / Commits / 04812de2

Refactor Config System (#338)

Unverified commit 04812de2, authored Sep 29, 2025 by Yang Yong (雍洋), committed by GitHub on Sep 29, 2025.
Parent: 6a658f42
Changes: 291 files in this commit; this page shows 20 changed files with 157 additions and 109 deletions (+157 -109).
Files changed on this page:

configs/wan22/wan_moe_i2v_distill.json (+10 -3)
configs/wan22/wan_moe_i2v_distill_quant.json (+10 -3)
configs/wan22/wan_moe_t2v.json (+4 -2)
configs/wan22/wan_moe_t2v_distill.json (+10 -3)
configs/wan22/wan_ti2v_i2v.json (+5 -2)
configs/wan22/wan_ti2v_i2v_4090.json (+5 -2)
configs/wan22/wan_ti2v_t2v.json (+5 -2)
configs/wan22/wan_ti2v_t2v_4090.json (+5 -2)
docs/EN/source/method_tutorials/video_frame_interpolation.md (+3 -3)
docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md (+3 -3)
lightx2v/deploy/worker/hub.py (+3 -3)
lightx2v/infer.py (+11 -6)
lightx2v/models/networks/wan/audio_model.py (+25 -17)
lightx2v/models/networks/wan/infer/audio/pre_infer.py (+2 -2)
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py (+2 -2)
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py (+32 -32)
lightx2v/models/networks/wan/infer/pre_infer.py (+1 -1)
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py (+5 -5)
lightx2v/models/networks/wan/infer/transformer_infer.py (+7 -7)
lightx2v/models/networks/wan/model.py (+9 -9)
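Most of the hunks on this page replace attribute-style config access (config.infer_steps) with explicit dictionary indexing (config["infer_steps"]), while optional lookups keep their .get(key, default) form. A minimal sketch of the behavioral difference, using a hypothetical `AttrDict` as a stand-in for the pre-refactor config object (the real `set_config` implementation is not shown on this page):

```python
# Minimal sketch: why config["key"] and config.key can behave differently.
# AttrDict is a hypothetical stand-in for the old hybrid config object.

class AttrDict(dict):
    """Dict whose keys are also readable as attributes."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

config = AttrDict(infer_steps=50, enable_cfg=False)

assert config.infer_steps == config["infer_steps"] == 50  # both styles work here
assert config.get("teacache_thresh", 0.1) == 0.1          # .get() supplies a default

# After the refactor, every required key is read with one explicit style:
cutoff_steps = config["infer_steps"] - 1                  # raises KeyError loudly if missing
```

Standardizing on indexing for required keys makes a missing key fail with a clear KeyError instead of depending on the hybrid object's attribute fallback.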
configs/wan22/wan_moe_i2v_distill.json (view file @ 04812de2)

@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [3.5, 3.5],
+    "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,5 +19,10 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
+    "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250]
+    "denoising_step_list": [1000, 750, 500, 250]
 }
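The distill configs gain a boundary_step_index key alongside the existing denoising_step_list. A hypothetical illustration of how such an index could split the four-step distilled schedule between the Wan2.2 MoE high-noise and low-noise experts; the code that actually consumes these keys is not part of this page of the diff:

```python
# Hypothetical: how a boundary index might partition the distilled schedule.
# Values are taken from wan_moe_i2v_distill.json above; the split semantics
# are an assumption, not shown in this diff.

denoising_step_list = [1000, 750, 500, 250]
boundary_step_index = 2

high_noise_steps = denoising_step_list[:boundary_step_index]  # [1000, 750]
low_noise_steps = denoising_step_list[boundary_step_index:]   # [500, 250]
print(high_noise_steps, low_noise_steps)
```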
configs/wan22/wan_moe_i2v_distill_quant.json (view file @ 04812de2)

@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [3.5, 3.5],
+    "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,7 +19,12 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
+    "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
+    "denoising_step_list": [1000, 750, 500, 250],
     "mm_config": {
         "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
     },
configs/wan22/wan_moe_t2v.json (view file @ 04812de2)

@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [4.0, 3.0],
+    "sample_guide_scale": [4.0, 3.0],
     "sample_shift": 12.0,
     "enable_cfg": true,
     "cpu_offload": true,
configs/wan22/wan_moe_t2v_distill.json (view file @ 04812de2)

@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
-    "sample_guide_scale": [4.0, 3.0],
+    "sample_guide_scale": [4.0, 3.0],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -16,7 +18,12 @@
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
+    "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
+    "denoising_step_list": [1000, 750, 500, 250],
     "lora_configs": [
         {
             "name": "low_noise_model",
configs/wan22/wan_ti2v_i2v.json (view file @ 04812de2)

@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
configs/wan22/wan_ti2v_i2v_4090.json (view file @ 04812de2)

@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
configs/wan22/wan_ti2v_t2v.json (view file @ 04812de2)

@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
configs/wan22/wan_ti2v_t2v_4090.json (view file @ 04812de2)

@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [4, 16, 16],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
docs/EN/source/method_tutorials/video_frame_interpolation.md (view file @ 04812de2)

@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "A beautiful sunset over the ocean" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```
 ### Configuration Parameters
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "A cat playing in the garden" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```
 ### Higher Frame Rate Enhancement
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "Smooth camera movement" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```
 ## Performance Considerations
docs/ZH_CN/source/method_tutorials/video_frame_interpolation.md (view file @ 04812de2)

@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "A beautiful sunset over the sea" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```
 ### Configuration Parameters
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "A kitten playing in the garden" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```
 ### Higher Frame Rate Enhancement
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "Smooth camera movement" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```
 ## Performance Considerations
lightx2v/deploy/worker/hub.py (view file @ 04812de2)

@@ -23,7 +23,7 @@ from lightx2v.utils.utils import seed_all
 class BaseWorker:
     @ProfilingContext4DebugL1("Init Worker Worker Cost:")
     def __init__(self, args):
-        args.save_video_path = ""
+        args.save_result_path = ""
         config = set_config(args)
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         seed_all(config.seed)
@@ -49,7 +49,7 @@ class BaseWorker:
         self.runner.config["prompt"] = params["prompt"]
         self.runner.config["negative_prompt"] = params.get("negative_prompt", "")
         self.runner.config["image_path"] = params.get("image_path", "")
-        self.runner.config["save_video_path"] = params.get("save_video_path", "")
+        self.runner.config["save_result_path"] = params.get("save_result_path", "")
         self.runner.config["seed"] = params.get("seed", self.fixed_config.get("seed", 42))
         self.runner.config["audio_path"] = params.get("audio_path", "")
@@ -92,7 +92,7 @@ class BaseWorker:
         if stream_video_path is not None:
             tmp_video_path = stream_video_path
-        params["save_video_path"] = tmp_video_path
+        params["save_result_path"] = tmp_video_path
         return tmp_video_path, output_video_path

     async def prepare_dit_inputs(self, inputs, data_manager):
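The worker now writes and reads save_result_path wherever it previously used save_video_path. A caller-side sketch with a hypothetical params dict:

```python
# Hypothetical request payload for BaseWorker: the key is now
# "save_result_path"; an old client still sending "save_video_path"
# would fall through to the "" default in params.get(...).
params = {"prompt": "A beautiful sunset over the ocean", "save_result_path": "./output.mp4"}

save_to = params.get("save_result_path", "")  # pre-refactor: params.get("save_video_path", "")
print(save_to)  # ./output.mp4
```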
lightx2v/infer.py (view file @ 04812de2)

@@ -17,6 +17,7 @@ from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
 from lightx2v.utils.envs import *
+from lightx2v.utils.input_info import set_input_info
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
@@ -24,15 +25,15 @@ from lightx2v.utils.utils import seed_all
 def init_runner(config):
     seed_all(config.seed)
     torch.set_grad_enabled(False)
-    runner = RUNNER_REGISTER[config.model_cls](config)
+    runner = RUNNER_REGISTER[config["model_cls"]](config)
     runner.init_modules()
     return runner


 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
     parser.add_argument("--model_cls", type=str,
@@ -58,7 +59,7 @@ def main():
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--sf_model_path", type=str, required=False)
     parser.add_argument("--config_json", type=str, required=True)
@@ -91,13 +92,16 @@ def main():
         help="The file of the source mask. Default None.",
     )
-    parser.add_argument("--save_video_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
     args = parser.parse_args()
     seed_all(args.seed)
     # set config
     config = set_config(args)
-    if config.parallel:
+    if config["parallel"]:
         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank())
         set_parallel_config(config)
@@ -106,7 +110,8 @@ def main():
     with ProfilingContext4DebugL1("Total Cost"):
         runner = init_runner(config)
-        runner.run_pipeline()
+        input_info = set_input_info(args)
+        runner.run_pipeline(input_info)

     # Clean up distributed process group
     if dist.is_initialized():
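Pieced together from the hunks above, the entry point now builds a static dict-style config plus a separate per-request input_info, and run_pipeline takes the latter as an argument. A condensed sketch of that flow; it requires the lightx2v repo, and the bodies of set_config and set_input_info are not shown in this diff:

```python
import torch
import torch.distributed as dist

from lightx2v.utils.input_info import set_input_info  # new import in this commit
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import set_config, set_parallel_config
from lightx2v.utils.utils import seed_all


def run(args):
    seed_all(args.seed)
    config = set_config(args)                 # static, dict-style config
    if config["parallel"]:                    # was: config.parallel
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
        set_parallel_config(config)
    runner = RUNNER_REGISTER[config["model_cls"]](config)  # was: config.model_cls
    runner.init_modules()
    input_info = set_input_info(args)         # new: per-request inputs travel separately
    runner.run_pipeline(input_info)           # was: runner.run_pipeline()
```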
lightx2v/models/networks/wan/audio_model.py (view file @ 04812de2)

@@ -2,6 +2,7 @@ import os
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from loguru import logger

 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
@@ -35,11 +36,11 @@ class WanAudioModel(WanModel):
             raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
         else:
             adapter_model_name = "audio_adapter_model.safetensors"
-        self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
+        self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name)
         adapter_offload = self.config.get("cpu_offload", False)
         load_from_rank0 = self.config.get("load_from_rank0", False)
-        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
+        self.adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
         if not adapter_offload:
             if not dist.is_initialized() or not load_from_rank0:
                 for key in self.adapter_weights_dict:
@@ -51,17 +52,17 @@ class WanAudioModel(WanModel):
         self.post_infer_class = WanAudioPostInfer
         self.transformer_infer_class = WanAudioTransformerInfer

-    def get_graph_name(self, shape, audio_num):
-        return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"
+    def get_graph_name(self, shape, audio_num, with_mask=True):
+        return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"

-    def start_compile(self, shape, audio_num):
-        graph_name = self.get_graph_name(shape, audio_num)
+    def start_compile(self, shape, audio_num, with_mask=True):
+        graph_name = self.get_graph_name(shape, audio_num, with_mask)
         logger.info(f"[Compile] Compile shape: {shape}, audio_num: {audio_num}, graph_name: {graph_name}")
         target_video_length = self.config.get("target_video_length", 81)
         latents_length = (target_video_length - 1) // 16 * 4 + 1
-        latents_h = shape[0] // self.config.vae_stride[1]
-        latents_w = shape[1] // self.config.vae_stride[2]
+        latents_h = shape[0] // self.config["vae_stride"][1]
+        latents_w = shape[1] // self.config["vae_stride"][2]
         new_inputs = {}
         new_inputs["text_encoder_output"] = {}
@@ -73,7 +74,11 @@ class WanAudioModel(WanModel):
         new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
         new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
-        new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        if with_mask:
+            new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        else:
+            assert audio_num == 1, "audio_num must be 1 when with_mask is False"
+            new_inputs["person_mask_latens"] = None
         new_inputs["previmg_encoder_output"] = {}
         new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
@@ -90,19 +95,21 @@ class WanAudioModel(WanModel):
         self.enable_compile_mode("_infer_cond_uncond")
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
                 self.transformer_weights.non_block_weights_to_cuda()
-        max_audio_num_num = self.config.get("compile_max_audios", 1)
+        max_audio_num_num = self.config.get("compile_max_audios", 3)
         for audio_num in range(1, max_audio_num_num + 1):
             for shape in compile_shapes:
-                self.start_compile(shape, audio_num)
+                self.start_compile(shape, audio_num, with_mask=True)
+                if audio_num == 1:
+                    self.start_compile(shape, audio_num, with_mask=False)
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -115,9 +122,10 @@ class WanAudioModel(WanModel):
         for shape in compile_shapes:
             assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]

-    def select_graph_for_compile(self):
-        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}, audio_num: {self.config.get('audio_num')}")
-        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}_{self.config.get('audio_num')}audio")
+    def select_graph_for_compile(self, input_info):
+        logger.info(f"target_h, target_w : {input_info.target_shape[0]}, {input_info.target_shape[1]}, audio_num: {input_info.audio_num}")
+        graph_name = self.get_graph_name(input_info.target_shape, input_info.audio_num, with_mask=input_info.with_mask)
+        self.select_graph("_infer_cond_uncond", graph_name)
         logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

     @torch.no_grad()
@@ -138,7 +146,7 @@ class WanAudioModel(WanModel):
         if person_mask_latens is not None:
             pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]

-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
             padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
             if padding_size > 0:
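The CUDA-graph cache key now records whether a person mask was compiled in. Reproducing the two f-strings from the hunk above to show how the key changes:

```python
# Both functions are copied from the hunk above; only the print calls are new.

def old_graph_name(shape, audio_num):
    return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"

def new_graph_name(shape, audio_num, with_mask=True):
    return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"

print(old_graph_name([720, 1280], 2))                   # graph_720x1280_2audio
print(new_graph_name([720, 1280], 2))                   # graph_720x1280_audio_num_2_mask_True
print(new_graph_name([480, 832], 1, with_mask=False))   # graph_480x832_audio_num_1_mask_False
```

Because the mask flag is part of the key, the mask-free graph (compiled only for audio_num == 1 in the loop above) can never collide with a masked graph of the same shape.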
lightx2v/models/networks/wan/infer/audio/pre_infer.py (view file @ 04812de2)

@@ -33,7 +33,7 @@ class WanAudioPreInfer(WanPreInfer):
         infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         hidden_states = latents
-        if self.config.model_cls != "wan2.2_audio":
+        if self.config["model_cls"] != "wan2.2_audio":
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
             hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
@@ -101,7 +101,7 @@ class WanAudioPreInfer(WanPreInfer):
             del out
             torch.cuda.empty_cache()

-        if self.task == "i2v" and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "s2v"] and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
             if self.clean_cuda_cache:
                 del clip_fea
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py (view file @ 04812de2)

@@ -140,7 +140,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
     def infer_cross_attn(self, weights, x, context, block_idx):
         norm3_out = weights.norm3.apply(x)
-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             context_img = context[:257]
             context = context[257:]
@@ -169,7 +169,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
             model_cls=self.config["model_cls"],
         )
-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py (view file @ 04812de2)

@@ -25,22 +25,22 @@ class WanTransformerInferCaching(WanOffloadTransformerInfer):
 class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
         self.accumulated_rel_l1_distance_odd = 0
         self.previous_e0_odd = None
         self.previous_residual_odd = None
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5
-            self.cutoff_steps = self.config.infer_steps
+            self.cutoff_steps = self.config["infer_steps"]
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1
-            self.cutoff_steps = self.config.infer_steps - 1
+            self.cutoff_steps = self.config["infer_steps"] - 1

     # calculate should_calc
     @torch.no_grad()
@@ -216,7 +216,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -353,7 +353,7 @@ class WanTransformerInferAdaCaching(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(xt)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -515,7 +515,7 @@ class AdaArgs:
         # Moreg related attributes
         self.previous_moreg = 1.0
         self.moreg_strides = [1]
-        self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
+        self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])]
         self.moreg_hyp = [0.385, 8, 1, 2]
         self.mograd_mul = 10
         self.spatial_dim = 1536
@@ -525,7 +525,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
@@ -534,15 +534,15 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         self.previous_residual_odd = None
         self.cache_even = {}
         self.cache_odd = {}
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
+            self.cutoff_steps = self.config["infer_steps"] * 2
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2
+            self.cutoff_steps = self.config["infer_steps"] * 2 - 2

     # 1. get taylor step_diff when there is two caching_records in scheduler
     def get_taylor_step_diff(self):
@@ -625,7 +625,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(x)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         self.cnt += 1
@@ -690,12 +690,12 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_first_block_residual_even = None
         self.prev_remaining_blocks_residual_even = None
         self.prev_first_block_residual_odd = None
         self.prev_remaining_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -727,7 +727,7 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(x)
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -795,12 +795,12 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_front_blocks_residual_even = None
         self.prev_middle_blocks_residual_even = None
         self.prev_front_blocks_residual_odd = None
         self.prev_middle_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -854,7 +854,7 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
                 context,
             )
-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()
         return x
@@ -921,8 +921,8 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
 class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
-        self.downsample_factor = self.config.downsample_factor
+        self.residual_diff_threshold = config["residual_diff_threshold"]
+        self.downsample_factor = self.config["downsample_factor"]
         self.block_in_cache_even = {i: None for i in range(self.blocks_num)}
         self.block_residual_cache_even = {i: None for i in range(self.blocks_num)}
@@ -992,10 +992,10 @@ class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
 class WanTransformerInferMagCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.magcache_thresh = config.magcache_thresh
-        self.K = config.magcache_K
-        self.retention_ratio = config.magcache_retention_ratio
-        self.mag_ratios = np.array(config.magcache_ratios)
+        self.magcache_thresh = config["magcache_thresh"]
+        self.K = config["magcache_K"]
+        self.retention_ratio = config["magcache_retention_ratio"]
+        self.mag_ratios = np.array(config["magcache_ratios"])
         # {True: cond_param, False: uncond_param}
         self.accumulated_err = {True: 0.0, False: 0.0}
         self.accumulated_steps = {True: 0, False: 0}
@@ -1011,10 +1011,10 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         step_index = self.scheduler.step_index
         infer_condition = self.scheduler.infer_condition
-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             skip_forward = False
         else:
-            if step_index >= int(self.config.infer_steps * self.retention_ratio):
+            if step_index >= int(self.config["infer_steps"] * self.retention_ratio):
                 # conditional and unconditional in one list
                 cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index]
                 # magnitude ratio between current step and the cached step
@@ -1054,7 +1054,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
             if self.config["cpu_offload"]:
                 previous_residual = previous_residual.cpu()
-            if self.config.magcache_calibration and step_index >= 1:
+            if self.config["magcache_calibration"] and step_index >= 1:
                 norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item()
                 norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item()
                 cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item()
@@ -1083,7 +1083,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         self.accumulated_steps = {True: 0, False: 0}
         self.accumulated_ratio = {True: 1.0, False: 1.0}
         self.residual_cache = {True: None, False: None}
-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             print("norm ratio")
             print(self.norm_ratio)
             print("norm std")
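After the refactor, TeaCaching reads its threshold and schedule from plain dict keys. A standalone rendering of the use_ret_steps branch above; the coefficient values here are illustrative placeholders, not the real calibrated ones:

```python
# Key names match the hunk above; the coefficient numbers are made up.
config = {
    "use_ret_steps": True,
    "coefficients": [[1.0, -0.5], [0.8, -0.3]],  # illustrative values only
    "infer_steps": 50,
}

if config["use_ret_steps"]:
    coefficients = config["coefficients"][0]
    ret_steps = 5
    cutoff_steps = config["infer_steps"]
else:
    coefficients = config["coefficients"][1]
    ret_steps = 1
    cutoff_steps = config["infer_steps"] - 1

print(coefficients, ret_steps, cutoff_steps)  # [1.0, -0.5] 5 50
```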
lightx2v/models/networks/wan/infer/pre_infer.py (view file @ 04812de2)

@@ -41,7 +41,7 @@ class WanPreInfer:
         else:
             context = inputs["text_encoder_output"]["context_null"]
-        if self.task in ["i2v", "flf2v", "animate"]:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"]:
             if self.config.get("use_image_encoder", True):
                 clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py (view file @ 04812de2)

@@ -39,12 +39,12 @@ class WanSFTransformerInfer(WanTransformerInfer):
         else:
             self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
-        sf_config = self.config.sf_config
-        self.local_attn_size = sf_config.local_attn_size
+        sf_config = self.config["sf_config"]
+        self.local_attn_size = sf_config["local_attn_size"]
         self.max_attention_size = 32760 if self.local_attn_size == -1 else self.local_attn_size * 1560
-        self.num_frame_per_block = sf_config.num_frame_per_block
-        self.num_transformer_blocks = sf_config.num_transformer_blocks
-        self.frame_seq_length = sf_config.frame_seq_length
+        self.num_frame_per_block = sf_config["num_frame_per_block"]
+        self.num_transformer_blocks = sf_config["num_transformer_blocks"]
+        self.frame_seq_length = sf_config["frame_seq_length"]
         self._initialize_kv_cache(self.device, self.dtype)
         self._initialize_crossattn_cache(self.device, self.dtype)
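The self-forcing path now indexes sf_config as a nested dict. A small sketch of the attention-window arithmetic from the hunk, with illustrative values for the keys:

```python
# Only the -1 sentinel and the 32760 / 1560 constants come from the hunk above;
# the sf_config values are illustrative placeholders.
sf_config = {
    "local_attn_size": -1,
    "num_frame_per_block": 1,
    "num_transformer_blocks": 30,
    "frame_seq_length": 1560,
}

local_attn_size = sf_config["local_attn_size"]
max_attention_size = 32760 if local_attn_size == -1 else local_attn_size * 1560
print(max_attention_size)  # 32760: the full window when local_attn_size == -1
```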
lightx2v/models/networks/wan/infer/transformer_infer.py (view file @ 04812de2)

@@ -11,13 +11,13 @@ from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, comp
 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
-        self.task = config.task
+        self.task = config["task"]
         self.attention_type = config.get("attention_type", "flash_attn2")
-        self.blocks_num = config.num_layers
+        self.blocks_num = config["num_layers"]
         self.phases_num = 3
         self.has_post_adapter = False
-        self.num_heads = config.num_heads
-        self.head_dim = config.dim // config.num_heads
+        self.num_heads = config["num_heads"]
+        self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
         if config.get("rotary_chunk", False):
@@ -203,7 +203,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             x.add_(y_out * gate_msa.squeeze())
         norm3_out = phase.norm3.apply(x)
-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
         else:
@@ -211,7 +211,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.sensitive_layer_dtype != self.infer_dtype:
             context = context.to(self.infer_dtype)
-            if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+            if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
                 context_img = context_img.to(self.infer_dtype)

         n, d = self.num_heads, self.head_dim
@@ -234,7 +234,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             model_cls=self.config["model_cls"],
         )
-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True) and context_img is not None:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
             k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
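WanTransformerInfer now pulls required model geometry with indexing while optional keys keep their .get() defaults. A sketch with illustrative dim and num_heads values:

```python
# Key names match the hunk above; the numeric values are illustrative.
config = {"task": "i2v", "num_layers": 40, "num_heads": 40, "dim": 5120}

task = config["task"]                              # required: KeyError if absent
blocks_num = config["num_layers"]
num_heads = config["num_heads"]
head_dim = config["dim"] // config["num_heads"]    # 5120 // 40 == 128
window_size = config.get("window_size", (-1, -1))  # optional: default preserved
print(task, blocks_num, num_heads, head_dim, window_size)
```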
lightx2v/models/networks/wan/model.py (view file @ 04812de2)

@@ -62,15 +62,15 @@ class WanModel(CompiledMethodsMixin):
         self.init_empty_model = False
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)

-        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
+        self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
         if self.dit_quantized:
-            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
-            if self.config.model_cls == "wan2.1_distill":
+            dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
+            if self.config["model_cls"] == "wan2.1_distill":
                 dit_quant_scheme = "distill_" + dit_quant_scheme
             if dit_quant_scheme == "gguf":
                 self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
-                self.config.use_gguf = True
+                self.config["use_gguf"] = True
             else:
                 self.dit_quantized_ckpt = find_hf_model_path(
                     config,
@@ -87,7 +87,7 @@ class WanModel(CompiledMethodsMixin):
             self.dit_quantized_ckpt = None
             assert not self.config.get("lazy_load", False)

-        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
+        self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
@@ -158,7 +158,7 @@ class WanModel(CompiledMethodsMixin):
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
-                if self.config.adapter_model_path == file_path:
+                if self.config["adapter_model_path"] == file_path:
                     continue
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
@@ -367,7 +367,7 @@ class WanModel(CompiledMethodsMixin):
     @torch.no_grad()
     def infer(self, inputs):
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
@@ -400,7 +400,7 @@ class WanModel(CompiledMethodsMixin):
             self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)

         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -441,7 +441,7 @@ class WanModel(CompiledMethodsMixin):
             pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]

-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
            embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
            padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
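WanModel derives its quantization scheme from the mm_type string. Replaying the hunk's logic with the mm_type value that appears in wan_moe_i2v_distill_quant.json earlier on this page:

```python
# Logic copied from the hunk above; config values come from the quant config
# shown earlier on this page.
config = {
    "mm_config": {"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"},
    "model_cls": "wan2.1_distill",
}

dit_quantized = config["mm_config"].get("mm_type", "Default") != "Default"
dit_quant_scheme = config["mm_config"].get("mm_type").split("-")[1]  # -> "int8"
if config["model_cls"] == "wan2.1_distill":
    dit_quant_scheme = "distill_" + dit_quant_scheme                 # -> "distill_int8"
print(dit_quantized, dit_quant_scheme)  # True distill_int8
```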