Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
ec79c145
Commit
ec79c145
authored
Jul 29, 2025
by
helloyongyang
Browse files
Support wan2.2 moe t2v model
parent
6e46224f
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
175 additions
and
20 deletions
+175
-20
configs/wan22/wan_t2v.json
configs/wan22/wan_t2v.json
+17
-0
lightx2v/api_server.py
lightx2v/api_server.py
+2
-1
lightx2v/infer.py
lightx2v/infer.py
+2
-2
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+18
-15
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+26
-1
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+67
-1
lightx2v/models/schedulers/wan/scheduler.py
lightx2v/models/schedulers/wan/scheduler.py
+1
-0
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+4
-0
scripts/wan/run_wan22_moe_t2v.sh
scripts/wan/run_wan22_moe_t2v.sh
+38
-0
No files found.
configs/wan22/wan_t2v.json
0 → 100755
View file @
ec79c145
{
    "infer_steps": 40,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [3.0, 4.0],
    "sample_shift": 12.0,
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.875
}
lightx2v/api_server.py
View file @
ec79c145
...
...
@@ -40,8 +40,9 @@ def main():
"wan2.1_causvid"
,
"wan2.1_skyreels_v2_df"
,
"wan2.1_audio"
,
"wan2.2_moe"
,
],
default
=
"
hunyuan
"
,
default
=
"
wan2.1
"
,
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
choices
=
[
"t2v"
,
"i2v"
],
default
=
"t2v"
)
parser
.
add_argument
(
"--model_path"
,
type
=
str
,
required
=
True
)
...
...
lightx2v/infer.py
View file @
ec79c145
...
...
@@ -10,7 +10,7 @@ from lightx2v.utils.set_config import set_config
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.models.runners.hunyuan.hunyuan_runner
import
HunyuanRunner
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
,
Wan22MoeRunner
from
lightx2v.models.runners.wan.wan_distill_runner
import
WanDistillRunner
from
lightx2v.models.runners.wan.wan_causvid_runner
import
WanCausVidRunner
from
lightx2v.models.runners.wan.wan_audio_runner
import
WanAudioRunner
...
...
@@ -42,7 +42,7 @@ def init_runner(config):
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model_cls"
,
type
=
str
,
required
=
True
,
choices
=
[
"wan2.1"
,
"hunyuan"
,
"wan2.1_distill"
,
"wan2.1_causvid"
,
"wan2.1_skyreels_v2_df"
,
"cogvideox"
,
"wan2.1_audio"
],
default
=
"wan2.1"
"--model_cls"
,
type
=
str
,
required
=
True
,
choices
=
[
"wan2.1"
,
"hunyuan"
,
"wan2.1_distill"
,
"wan2.1_causvid"
,
"wan2.1_skyreels_v2_df"
,
"cogvideox"
,
"wan2.1_audio"
,
"wan2.2_moe"
],
default
=
"wan2.1"
)
parser
.
add_argument
(
"--task"
,
type
=
str
,
choices
=
[
"t2v"
,
"i2v"
],
default
=
"t2v"
)
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
ec79c145
...
...
@@ -46,22 +46,25 @@ class WanTransformerInfer(BaseTransformerInfer):
self
.
infer_func
=
self
.
_infer_with_phases_offload
else
:
self
.
infer_func
=
self
.
_infer_with_phases_lazy_offload
elif
offload_granularity
==
"model"
:
self
.
infer_func
=
self
.
_infer_without_offload
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
weights_stream_mgr
=
WeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
)
else
:
self
.
weights_stream_mgr
=
LazyWeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
num_disk_workers
=
self
.
config
.
get
(
"num_disk_workers"
,
2
),
max_memory
=
self
.
config
.
get
(
"max_memory"
,
2
),
offload_gra
=
offload_granularity
,
)
if
offload_granularity
!=
"model"
:
if
not
self
.
config
.
get
(
"lazy_load"
,
False
):
self
.
weights_stream_mgr
=
WeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
)
else
:
self
.
weights_stream_mgr
=
LazyWeightAsyncStreamManager
(
blocks_num
=
self
.
blocks_num
,
offload_ratio
=
offload_ratio
,
phases_num
=
self
.
phases_num
,
num_disk_workers
=
self
.
config
.
get
(
"num_disk_workers"
,
2
),
max_memory
=
self
.
config
.
get
(
"max_memory"
,
2
),
offload_gra
=
offload_granularity
,
)
else
:
self
.
infer_func
=
self
.
_infer_without_offload
...
...
lightx2v/models/networks/wan/model.py
View file @
ec79c145
...
...
@@ -226,7 +226,7 @@ class WanModel:
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
config
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
self
.
scheduler
.
noise_pred
=
noise_pred_uncond
+
self
.
scheduler
.
sample_guide_scale
*
(
self
.
scheduler
.
noise_pred
-
noise_pred_uncond
)
if
self
.
config
.
get
(
"cpu_offload"
,
False
):
self
.
pre_weight
.
to_cpu
()
...
...
@@ -235,3 +235,28 @@ class WanModel:
if
self
.
clean_cuda_cache
:
del
x
,
embed
,
pre_infer_out
,
noise_pred_uncond
,
grid_sizes
torch
.
cuda
.
empty_cache
()
class Wan22MoeModel(WanModel):
    """One expert of the Wan2.2 MoE t2v pair (high- or low-noise).

    Differs from the base WanModel in two ways visible here: the checkpoint
    is loaded by merging every ``*.safetensors`` shard found directly under
    ``model_path``, and ``infer`` reads the guidance scale from the shared
    scheduler (``self.scheduler.sample_guide_scale``) rather than from the
    static config, so MultiModelStruct can switch the scale per expert.
    """

    def _load_ckpt(self, use_bf16, skip_bf16):
        """Merge all safetensors shards under ``self.model_path`` into one dict.

        ``use_bf16`` / ``skip_bf16`` are forwarded to the per-file loader;
        presumably they control bf16 casting — confirm against the base class.
        """
        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
            # Later shards overwrite duplicate keys, if any.
            weight_dict.update(file_weights)
        return weight_dict

    @torch.no_grad()
    def infer(self, inputs):
        """Run one denoising step, writing the result to ``self.scheduler.noise_pred``.

        Conditional pass first; when ``enable_cfg`` is set, a second
        unconditional pass follows and the two are combined with
        classifier-free guidance using the scheduler's current scale.
        """
        # Conditional (positive-prompt) pass.
        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
        self.scheduler.noise_pred = noise_pred_cond
        if self.config["enable_cfg"]:
            # Unconditional (negative-prompt) pass, then CFG mix:
            # uncond + scale * (cond - uncond). The scale comes from the
            # scheduler so it can differ between the two experts.
            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
lightx2v/models/runners/wan/wan_runner.py
View file @
ec79c145
...
...
@@ -18,7 +18,7 @@ from lightx2v.utils.profiler import ProfilingContext
from
lightx2v.utils.utils
import
*
from
lightx2v.models.input_encoders.hf.t5.model
import
T5EncoderModel
from
lightx2v.models.input_encoders.hf.xlm_roberta.model
import
CLIPModel
from
lightx2v.models.networks.wan.model
import
WanModel
from
lightx2v.models.networks.wan.model
import
WanModel
,
Wan22MoeModel
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.video_encoders.hf.wan.vae
import
WanVAE
from
lightx2v.models.video_encoders.hf.wan.vae_tiny
import
WanVAE_tiny
...
...
@@ -293,3 +293,69 @@ class WanRunner(DefaultRunner):
normalize
=
True
,
value_range
=
(
-
1
,
1
),
)
class MultiModelStruct:
    """Present the Wan2.2 MoE expert pair as a single model to the runner.

    ``model_list`` is ``[high_noise_model, low_noise_model]``. At each
    denoising step the expert is chosen by comparing the current timestep
    against ``boundary * num_train_timesteps``; only the active expert is
    kept on the GPU, the other is offloaded to CPU.
    """

    def __init__(self, model_list, config, boundary=0.875, num_train_timesteps=1000):
        self.model = model_list  # [high_noise_model, low_noise_model]
        assert len(self.model) == 2, "MultiModelStruct only supports 2 models now."
        self.config = config
        self.boundary = boundary
        self.boundary_timestep = self.boundary * num_train_timesteps
        # -1 means no expert has been moved to the GPU yet.
        self.cur_model_index = -1
        logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

    def set_scheduler(self, shared_scheduler):
        """Attach one scheduler instance shared by both experts."""
        self.scheduler = shared_scheduler
        for model in self.model:
            model.set_scheduler(shared_scheduler)

    def infer(self, inputs):
        """Select the expert for the current step and run its inference."""
        self.get_current_model_index()
        self.model[self.cur_model_index].infer(inputs)

    def get_current_model_index(self):
        """Pick the expert for the current timestep and swap GPU residency.

        Fix over the original shape: the HIGH/LOW branches duplicated the
        select/log/offload/load sequence; this computes the target index once
        and shares the swap logic, with identical observable behavior.
        """
        # Index 0 = high-noise expert (early steps, timestep >= boundary),
        # index 1 = low-noise expert (late steps).
        if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
            target_index = 0
            logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
        else:
            target_index = 1
            logger.info(f"using - LOW - noise model at step_index {self.scheduler.step_index + 1}")
        # Per-expert guidance scale: config.sample_guide_scale is [high, low].
        self.scheduler.sample_guide_scale = self.config.sample_guide_scale[target_index]
        if self.cur_model_index != target_index:
            if self.cur_model_index != -1:
                # Switching experts: free the old one before loading the new one.
                self.offload_cpu(model_index=self.cur_model_index)
            self.to_cuda(model_index=target_index)
            self.cur_model_index = target_index

    def offload_cpu(self, model_index):
        """Move the given expert's weights to CPU."""
        self.model[model_index].to_cpu()

    def to_cuda(self, model_index):
        """Move the given expert's weights to the GPU."""
        self.model[model_index].to_cuda()
@RUNNER_REGISTER("wan2.2_moe")
class Wan22MoeRunner(WanRunner):
    """Runner for the Wan2.2 MoE t2v pipeline: loads the two noise experts
    and hands them to the shared WanRunner machinery as one model."""

    def __init__(self, config):
        super().__init__(config)

    def load_transformer(self):
        """Build both experts and wrap them in a MultiModelStruct.

        Pipeline order: encoder -> high_noise_model -> low_noise_model ->
        vae -> video_output.
        """
        experts = [
            Wan22MoeModel(
                os.path.join(self.config.model_path, subdir),
                self.config,
                self.init_device,
            )
            for subdir in ("high_noise_model", "low_noise_model")
        ]
        return MultiModelStruct(experts, self.config, self.config.boundary)
lightx2v/models/schedulers/wan/scheduler.py
View file @
ec79c145
...
...
@@ -18,6 +18,7 @@ class WanScheduler(BaseScheduler):
self
.
disable_corrector
=
[]
self
.
solver_order
=
2
self
.
noise_pred
=
None
self
.
sample_guide_scale
=
self
.
config
.
sample_guide_scale
self
.
caching_records_2
=
[
True
]
*
self
.
config
.
infer_steps
...
...
lightx2v/utils/set_config.py
View file @
ec79c145
...
...
@@ -37,6 +37,10 @@ def set_config(args):
with
open
(
os
.
path
.
join
(
config
.
model_path
,
"config.json"
),
"r"
)
as
f
:
model_config
=
json
.
load
(
f
)
config
.
update
(
model_config
)
elif
os
.
path
.
exists
(
os
.
path
.
join
(
config
.
model_path
,
"low_noise_model"
,
"config.json"
)):
# TODO: need a more elegant way to update the config
with
open
(
os
.
path
.
join
(
config
.
model_path
,
"low_noise_model"
,
"config.json"
),
"r"
)
as
f
:
model_config
=
json
.
load
(
f
)
config
.
update
(
model_config
)
elif
os
.
path
.
exists
(
os
.
path
.
join
(
config
.
model_path
,
"original"
,
"config.json"
)):
with
open
(
os
.
path
.
join
(
config
.
model_path
,
"original"
,
"config.json"
),
"r"
)
as
f
:
model_config
=
json
.
load
(
f
)
...
...
scripts/wan/run_wan22_moe_t2v.sh
0 → 100755
View file @
ec79c145
#!/bin/bash
# Run Wan2.2 MoE text-to-video inference with lightx2v.
# Set these two paths first.
lightx2v_path=
model_path=

# --- sanity checks ---------------------------------------------------------
# Default to GPU 0 when the caller did not pin a device.
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

# --- environment -----------------------------------------------------------
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

# --- inference -------------------------------------------------------------
# Negative prompt is the standard Wan Chinese quality-suppression prompt;
# it is a runtime string and intentionally left untranslated.
python -m lightx2v.infer \
--model_cls wan2.2_moe \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan22/wan_t2v.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v.mp4
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment