xuwx1 / LightX2V · Commits · 978e3b32

Commit 978e3b32, authored Aug 06, 2025 by helloyongyang
update wan2.2moe
Parent: d50b8884
Showing 7 changed files with 72 additions and 53 deletions (+72, −53):
configs/dist_infer/wan22_moe_t2v_cfg.json  +20 −0
lightx2v/models/networks/wan/distill_model.py  +3 −9
lightx2v/models/networks/wan/model.py  +2 −35
lightx2v/models/runners/wan/wan_distill_runner.py  +3 −3
lightx2v/models/runners/wan/wan_runner.py  +3 −3
lightx2v/utils/utils.py  +3 −3
scripts/dist_infer/run_wan22_moe_t2v_cfg.sh  +38 −0
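In short: the commit deletes the dedicated Wan22MoeModel class, threads an explicit model_path through find_hf_model_path (and through WanModel's checkpoint loading), and has the Wan2.2 MoE runners construct each expert as a plain WanModel rooted at its own subdirectory. It also adds a distributed-inference config and a launch script for CFG-parallel Wan2.2 MoE text-to-video.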
configs/dist_infer/wan22_moe_t2v_cfg.json (new file, mode 100755):

```json
{
    "infer_steps": 40,
    "target_video_length": 81,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [4.0, 3.0],
    "sample_shift": 12.0,
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.875,
    "parallel": {
        "cfg_p_size": 2
    }
}
```
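For orientation: "sample_guide_scale" is a two-element list (one guidance scale per expert), "boundary" marks the point in denoising at which inference is expected to hand over from the high-noise to the low-noise model, and "parallel.cfg_p_size": 2 splits the conditional and unconditional CFG branches across two ranks. Below is a minimal sketch of how such a boundary might partition the step schedule; the helper name and the exact switching rule are assumptions for illustration, not code from this commit:

```python
# Hypothetical sketch: assign each denoising step to the high- or low-noise
# expert based on the "boundary" fraction from the config. The real LightX2V
# scheduler may switch on timestep values (and sample_shift) rather than on
# raw step indices.
import json

def split_steps_by_boundary(config_path: str):
    with open(config_path) as f:
        cfg = json.load(f)
    steps, boundary = cfg["infer_steps"], cfg["boundary"]
    # Steps whose remaining-noise fraction is above the boundary go to the
    # high-noise expert; the rest go to the low-noise expert.
    high = [s for s in range(steps) if 1 - s / steps > boundary]
    low = [s for s in range(steps) if 1 - s / steps <= boundary]
    return high, low

high, low = split_steps_by_boundary("configs/dist_infer/wan22_moe_t2v_cfg.json")
print(f"high-noise expert: {len(high)} steps, low-noise expert: {len(low)} steps")
```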
lightx2v/models/networks/wan/distill_model.py:

```diff
@@ -3,7 +3,7 @@ import os
 import torch
 from loguru import logger
-from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
+from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
@@ -32,16 +32,10 @@ class WanDistillModel(WanModel):
         return super()._load_ckpt(unified_dtype, sensitive_layer)
 
 
-class Wan22MoeDistillModel(WanDistillModel, Wan22MoeModel):
+class Wan22MoeDistillModel(WanDistillModel, WanModel):
     def __init__(self, model_path, config, device):
         WanDistillModel.__init__(self, model_path, config, device)
 
-    def _load_ckpt(self, unified_dtype, sensitive_layer):
-        ckpt_path = os.path.join(self.model_path, "distill_model.safetensors")
-        if os.path.exists(ckpt_path):
-            logger.info(f"Loading weights from {ckpt_path}")
-            return self._load_safetensor_to_dict(ckpt_path, unified_dtype, sensitive_layer)
-
     @torch.no_grad()
     def infer(self, inputs):
-        return Wan22MoeModel.infer(self, inputs)
+        return WanModel.infer(self, inputs)
```
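The change above drops the Wan22MoeModel base (that class is deleted from model.py in this commit) along with a redundant _load_ckpt override, so Wan22MoeDistillModel now dispatches infer explicitly to WanModel. A toy illustration of this explicit-dispatch pattern in multiple inheritance; the class names here are invented for the example:

```python
# Toy example (invented classes): with multiple base classes, calling
# Base.infer(self, ...) pins the call to one base instead of following
# Python's method resolution order.
class Base:
    def infer(self, x):
        return f"Base.infer({x})"

class Distill(Base):
    def infer(self, x):
        return f"Distill.infer({x})"

class Combined(Distill, Base):
    def infer(self, x):
        # Explicit dispatch: skip Distill's override and use Base's version,
        # mirroring how Wan22MoeDistillModel calls WanModel.infer(self, inputs).
        return Base.infer(self, x)

print(Combined().infer("inputs"))  # -> Base.infer(inputs)
```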
lightx2v/models/networks/wan/model.py:

```diff
@@ -55,7 +55,7 @@ class WanModel:
             self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
             self.config.use_gguf = True
         else:
-            self.dit_quantized_ckpt = find_hf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
+            self.dit_quantized_ckpt = find_hf_model_path(config, self.model_path, "dit_quantized_ckpt", subdir=dit_quant_scheme)
             quant_config_path = os.path.join(self.dit_quantized_ckpt, "config.json")
             if os.path.exists(quant_config_path):
                 with open(quant_config_path, "r") as f:
@@ -106,7 +106,7 @@ class WanModel:
             return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
+        safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
         safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
@@ -293,36 +293,3 @@ class WanModel:
             noise_pred_cond = noise_pred_list[0]  # cfg_p_rank == 0
             noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
             self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
-
-
-class Wan22MoeModel(WanModel):
-    def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
-        weight_dict = {}
-        for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
-            weight_dict.update(file_weights)
-        return weight_dict
-
-    @torch.no_grad()
-    def infer(self, inputs):
-        if self.cpu_offload and self.offload_granularity != "model":
-            self.pre_weight.to_cuda()
-            self.post_weight.to_cuda()
-
-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
-        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-        self.scheduler.noise_pred = noise_pred_cond
-
-        if self.config["enable_cfg"]:
-            embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
-            noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-            self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)
-
-        if self.cpu_offload and self.offload_granularity != "model":
-            self.pre_weight.to_cpu()
-            self.post_weight.to_cpu()
```
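The deleted Wan22MoeModel.infer ran classifier-free guidance sequentially: one conditional pass, one unconditional pass, then the combination uncond + scale * (cond - uncond). The surviving CFG-parallel path in WanModel (first hunk of the third diff section) computes the same combination from per-rank results. A minimal standalone sketch of that guidance formula; the tensor shape and scale value are illustrative only:

```python
# Minimal sketch of the classifier-free guidance combination used above.
# The tensors stand in for the conditional/unconditional noise predictions.
import torch

def apply_cfg(noise_pred_cond: torch.Tensor,
              noise_pred_uncond: torch.Tensor,
              guide_scale: float) -> torch.Tensor:
    # Same formula as in the diff: uncond + scale * (cond - uncond);
    # guide_scale = 1.0 reduces to the conditional prediction alone.
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

cond = torch.randn(1, 16, 21, 90, 160)  # illustrative latent shape, not the real one
uncond = torch.randn_like(cond)
guided = apply_cfg(cond, uncond, guide_scale=4.0)  # 4.0 is the first sample_guide_scale entry
print(guided.shape)
```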
lightx2v/models/runners/wan/wan_distill_runner.py:

```diff
@@ -4,7 +4,7 @@ from loguru import logger
 from lightx2v.models.networks.wan.distill_model import Wan22MoeDistillModel, WanDistillModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
+from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
@@ -86,7 +86,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             use_low_lora = True
 
         if use_high_lora:
-            high_noise_model = Wan22MoeModel(
+            high_noise_model = WanModel(
                 os.path.join(self.config.model_path, "high_noise_model"),
                 self.config,
                 self.init_device,
@@ -107,7 +107,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             )
 
         if use_low_lora:
-            low_noise_model = Wan22MoeModel(
+            low_noise_model = WanModel(
                 os.path.join(self.config.model_path, "low_noise_model"),
                 self.config,
                 self.init_device,
```
lightx2v/models/runners/wan/wan_runner.py:

```diff
@@ -11,7 +11,7 @@ from loguru import logger
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
 from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.networks.wan.model import Wan22MoeModel, WanModel
+from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.changing_resolution.scheduler import (
     WanScheduler4ChangingResolutionInterface,
@@ -370,12 +370,12 @@ class Wan22MoeRunner(WanRunner):
     def load_transformer(self):
         # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
-        high_noise_model = Wan22MoeModel(
+        high_noise_model = WanModel(
             os.path.join(self.config.model_path, "high_noise_model"),
             self.config,
             self.init_device,
         )
-        low_noise_model = Wan22MoeModel(
+        low_noise_model = WanModel(
             os.path.join(self.config.model_path, "low_noise_model"),
             self.config,
             self.init_device,
```
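Across both runners the pattern is the same: since checkpoint lookup is now rooted at an explicit model_path, each MoE expert is just a plain WanModel pointed at its own subdirectory (high_noise_model / low_noise_model), and the dedicated Wan22MoeModel subclass becomes unnecessary. The runners then group the two experts (via MultiModelStruct in this repo, whose signature is not shown in this diff). A hypothetical sketch of the per-subdirectory construction, with stand-in names rather than this repo's API:

```python
# Hypothetical sketch (not this repo's API): load each MoE expert as an
# ordinary model instance rooted at its own checkpoint subdirectory.
import os

class ExpertModel:  # stand-in for WanModel
    def __init__(self, model_path: str):
        self.model_path = model_path  # checkpoints resolved under this dir

def load_experts(base_path: str) -> dict[str, ExpertModel]:
    # Same directory layout as Wan2.2 MoE: one subdirectory per expert.
    return {
        name: ExpertModel(os.path.join(base_path, name))
        for name in ("high_noise_model", "low_noise_model")
    }

experts = load_experts("/path/to/wan2.2_moe")
print(experts["high_noise_model"].model_path)
```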
lightx2v/utils/utils.py:

```diff
@@ -277,14 +277,14 @@ def find_torch_model_path(config, ckpt_config_key=None, filename=None, subdir=["
         raise FileNotFoundError(
             f"PyTorch model file '{filename}' not found.\nPlease download the model from https://huggingface.co/lightx2v/ or specify the model path in the configuration file."
         )
 
 
-def find_hf_model_path(config, ckpt_config_key=None, subdir=["original", "fp8", "int8"]):
+def find_hf_model_path(config, model_path, ckpt_config_key=None, subdir=["original", "fp8", "int8"]):
     if ckpt_config_key and config.get(ckpt_config_key, None) is not None:
         return config.get(ckpt_config_key)
 
-    paths_to_check = [config.model_path]
+    paths_to_check = [model_path]
     if isinstance(subdir, list):
         for sub in subdir:
-            paths_to_check.append(os.path.join(config.model_path, sub))
+            paths_to_check.append(os.path.join(model_path, sub))
     else:
         paths_to_check.append(os.path.join(config.model_path, subdir))
```
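This signature change is what enables the runner simplification: callers such as the MoE runners pass a per-expert directory instead of always falling back to config.model_path. A self-contained sketch of the lookup order the function implements, simplified: the ckpt_config_key override is omitted, and the acceptance test (first existing directory containing *.safetensors files) is an assumption, since the real acceptance criterion is not shown in this diff:

```python
# Simplified re-sketch of find_hf_model_path's candidate ordering.
import glob
import os

def resolve_hf_path(model_path: str, subdirs=("original", "fp8", "int8")) -> str:
    # Check the root first, then each quantization subdirectory, in order.
    candidates = [model_path] + [os.path.join(model_path, s) for s in subdirs]
    for path in candidates:
        if os.path.isdir(path) and glob.glob(os.path.join(path, "*.safetensors")):
            return path
    raise FileNotFoundError(f"No safetensors found under {model_path}")

# Usage: resolve_hf_path("/path/to/wan2.2_moe/high_noise_model")
```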
scripts/dist_infer/run_wan22_moe_t2v_cfg.sh (new file, mode 100755):

```bash
#!/bin/bash

# set paths first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0,1
    echo "Warn: CUDA_VISIBLE_DEVICES is not set; using default value ${cuda_devices}. Change it in this script or set the environment variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

torchrun --nproc_per_node=2 -m lightx2v.infer \
    --model_cls wan2.2_moe \
    --task t2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/dist_infer/wan22_moe_t2v_cfg.json \
    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_t2v_parallel_cfg.mp4
```
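To run it, fill in lightx2v_path and model_path at the top, then execute the script with two GPUs visible: torchrun launches two ranks, matching "cfg_p_size": 2 in the config so the conditional and unconditional CFG branches run on separate devices. The negative prompt is the stock Chinese Wan negative prompt (roughly: garish tones, overexposed, static, blurry details, subtitles, style, artwork, painting, frame, motionless, grayish overall tone, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, malformed limbs, fused fingers, frozen frames, cluttered background, three legs, many people in the background, walking backwards).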