LightX2V (xuwx1) · Commit 9e120289

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

Authored Aug 14, 2025 by wangshankun
Parents: b5bcbed7, 9196a220

15 changed files with 51 additions and 54 deletions (+51 -54)
lightx2v/infer.py                                         +17  -6
lightx2v/models/input_encoders/hf/t5/model.py              +0  -1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py     +1  -2
lightx2v/models/networks/wan/audio_model.py                +2  -2
lightx2v/models/networks/wan/causvid_model.py              +2  -2
lightx2v/models/networks/wan/distill_model.py              +2  -2
lightx2v/models/networks/wan/infer/pre_infer.py            +0  -4
lightx2v/models/networks/wan/model.py                     +11  -9
lightx2v/models/runners/wan/wan_audio_runner.py            +1  -1
lightx2v/models/runners/wan/wan_causvid_runner.py          +0  -1
lightx2v/models/runners/wan/wan_distill_runner.py          +0  -3
lightx2v/models/runners/wan/wan_runner.py                  +0  -9
lightx2v/models/schedulers/wan/scheduler.py                +3  -6
lightx2v/models/video_encoders/hf/wan/vae.py               +2  -3
lightx2v/utils/set_config.py                              +10  -3

lightx2v/infer.py  +17 -6

 import argparse
 import json

 import torch.distributed as dist
 from loguru import logger
...
@@ -16,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.set_config import set_config, set_parallel_config
+from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
...
@@ -40,18 +39,31 @@ def main():
         type=str,
         required=True,
-        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2", "wan2.2_moe_audio"],
+        choices=[
+            "wan2.1",
+            "hunyuan",
+            "wan2.1_distill",
+            "wan2.1_causvid",
+            "wan2.1_skyreels_v2_df",
+            "cogvideox",
+            "wan2.1_audio",
+            "wan2.2_moe",
+            "wan2.2_moe_audio",
+            "wan2.2_audio",
+            "wan2.2",
+            "wan2.2_moe_distill",
+        ],
         default="wan2.1",
     )
...
@@ -70,17 +82,16 @@ def main():
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
     args = parser.parse_args()
     logger.info(f"args: {args}")
-    # set config
     config = set_config(args)
-    logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
     if config.parallel:
         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank())
         set_parallel_config(config)
+    print_config(config)

     with ProfilingContext("Total Cost"):
         runner = init_runner(config)
         runner.run_pipeline()
...
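
Review note: the reordering in the last hunk matters because print_config calls dist.get_rank(), which is only valid after dist.init_process_group. A minimal, self-contained sketch of the new ordering (the config handling is a stand-in, not the project's code; run the parallel path under torchrun):

import os

import torch
import torch.distributed as dist


def demo_main(parallel: bool) -> None:
    # Mirror of the new ordering in main(): initialize distributed state first,
    # because rank-aware helpers (set_parallel_config, print_config) call
    # dist.get_rank()/dist.get_world_size() internally.
    if parallel:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())  # pin this rank to its own GPU
    rank = dist.get_rank() if parallel else 0
    if rank == 0:
        print("config would be logged exactly once here, by rank 0")
    if parallel:
        dist.destroy_process_group()


if __name__ == "__main__":
    # torchrun sets RANK; a plain `python` invocation exercises the serial path.
    demo_main(parallel="RANK" in os.environ)
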

lightx2v/models/input_encoders/hf/t5/model.py  +0 -1
...
@@ -540,7 +540,6 @@ class T5EncoderModel:
         t5_quantized=False,
         t5_quantized_ckpt=None,
         quant_scheme=None,
-        seq_p_group=None,
     ):
         self.text_len = text_len
         self.dtype = dtype
...

lightx2v/models/input_encoders/hf/xlm_roberta/model.py  +1 -2
...
@@ -418,13 +418,12 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
 class CLIPModel:
-    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, seq_p_group=None):
+    def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True):
         self.dtype = dtype
         self.device = device
         self.quantized = clip_quantized
         self.cpu_offload = cpu_offload
         self.use_31_block = use_31_block
-        self.seq_p_group = seq_p_group
         if self.quantized:
             self.checkpoint_path = clip_quantized_ckpt
...

lightx2v/models/networks/wan/audio_model.py  +2 -2
...
@@ -16,8 +16,8 @@ class WanAudioModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)

     def _init_infer_class(self):
         super()._init_infer_class()
...

lightx2v/models/networks/wan/causvid_model.py  +2 -2
...
@@ -23,8 +23,8 @@ class WanCausVidModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
...

lightx2v/models/networks/wan/distill_model.py  +2 -2
...
@@ -19,8 +19,8 @@ class WanDistillModel(WanModel):
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
-        super().__init__(model_path, config, device, seq_p_group)
+    def __init__(self, model_path, config, device):
+        super().__init__(model_path, config, device)

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         # For the old t2v distill model: https://huggingface.co/lightx2v/Wan2.1-T2V-14B-StepDistill-CfgDistill
...

lightx2v/models/networks/wan/infer/pre_infer.py  +0 -4
...
@@ -74,10 +74,6 @@ class WanPreInfer:
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)

-        # wan2.2_moe expands t here; we found it makes little difference either way, and the expansion adds latency. We stayed faithful to the original code for now; this could be dropped later.
-        if self.config["model_cls"] == "wan2.2_moe":
-            t = t.expand(seq_lens[0])
-
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
             s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
...
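
Review note: the removed branch broadcast the scalar timestep t to every token before embedding, multiplying the embedding work by the sequence length while producing identical rows. A self-contained sketch of the shape difference, with a standard sin/cos embedding standing in for the project's sinusoidal_embedding_1d (the dimensions here are hypothetical):

import torch


def sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor:
    # Standard sinusoidal positional embedding; a stand-in for the project helper.
    half = dim // 2
    freqs = torch.pow(10000.0, -torch.arange(half, dtype=torch.float32) / half)
    angles = position.float().unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat([angles.sin(), angles.cos()], dim=1)


freq_dim, seq_len = 256, 32760   # illustrative latent token count
t = torch.tensor([500])          # one diffusion timestep for the whole clip

scalar_embed = sinusoidal_embedding_1d(freq_dim, t.flatten())
print(scalar_embed.shape)        # torch.Size([1, 256]) -- one row

t_expanded = t.expand(seq_len)   # what the removed wan2.2_moe branch did
per_token = sinusoidal_embedding_1d(freq_dim, t_expanded.flatten())
print(per_token.shape)           # torch.Size([32760, 256]) -- seq_len identical rows
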

lightx2v/models/networks/wan/model.py  +11 -9
...
@@ -41,12 +41,16 @@ class WanModel:
     post_weight_class = WanPostWeights
     transformer_weight_class = WanTransformerWeights

-    def __init__(self, model_path, config, device, seq_p_group=None):
+    def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
-        self.seq_p_group = seq_p_group
+        if self.config["seq_parallel"]:
+            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
...
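
Review note: WanModel now resolves its sequence-parallel group from config["device_mesh"] instead of receiving it as a constructor argument, which is what lets the runners and encoders below drop the seq_p_group plumbing. A runnable sketch of the mesh-to-group lookup, assuming a 1 x world_size mesh and the ("cfg_p", "seq_p") axis names used by set_parallel_config (launch with torchrun --nproc_per_node=2 on 2+ GPUs):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# Build a 2-D mesh; in LightX2V this would be stored as config["device_mesh"].
mesh = init_device_mesh("cuda", (1, dist.get_world_size()), mesh_dim_names=("cfg_p", "seq_p"))
seq_p_group = mesh.get_group(mesh_dim="seq_p")  # the lookup WanModel.__init__ now performs

# Global rank vs. rank within the sequence-parallel group.
print(dist.get_rank(), "->", dist.get_rank(seq_p_group))

dist.destroy_process_group()
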
@@ -252,8 +256,6 @@ class WanModel:
         if target_device == "cuda":
-            dist.barrier(device_ids=[torch.cuda.current_device()])
-        else:
             dist.barrier()
         for key in sorted(synced_meta_dict.keys()):
             if is_weight_loader:
...
@@ -390,11 +392,11 @@ class WanModel:
                 x = F.pad(x, (0, 0, 0, padding_size))  # (padding pairs run from the last dim to the first)
             x = torch.chunk(x, world_size, dim=0)[cur_rank]

-            # if self.config["model_cls"] == "wan2.2":
-            #     padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
-            #     if padding_size > 0:
-            #         embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # (padding pairs run from the last dim to the first)
-            #         embed = F.pad(embed, (0, 0, 0, padding_size))
+            if self.config["model_cls"].startswith("wan2.2"):
+                padding_size = (world_size - (embed0.shape[0] % world_size)) % world_size
+                if padding_size > 0:
+                    embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))  # (padding pairs run from the last dim to the first)
+                    embed = F.pad(embed, (0, 0, 0, padding_size))
             pre_infer_out.x = x
             pre_infer_out.embed = embed
...
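
Review note: the padding arithmetic in the hunk above rounds the leading dimension up to a multiple of world_size so torch.chunk gives every rank an equal slice; F.pad's pairs apply from the last dimension backwards, so (0, 0, 0, padding_size) pads dim -2 and (0, 0, 0, 0, 0, padding_size) pads dim -3. A CPU-only sketch with stand-in values:

import torch
import torch.nn.functional as F

world_size, cur_rank = 4, 1   # stand-ins for dist.get_world_size() / dist.get_rank()
embed = torch.randn(10, 8)    # leading dim 10 is not divisible by 4

padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
print(padding_size)           # 2: round 10 up to 12

# The final (0, padding_size) pair targets dim 0 here, the dim being sharded.
embed = F.pad(embed, (0, 0, 0, padding_size))
print(embed.shape)            # torch.Size([12, 8])

shard = torch.chunk(embed, world_size, dim=0)[cur_rank]
print(shard.shape)            # torch.Size([3, 8]) -- equal slice on every rank
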

lightx2v/models/runners/wan/wan_audio_runner.py  +1 -1
...
@@ -343,7 +343,7 @@ class VideoGenerator:
         self.model.scheduler.reset()
         inputs["previmg_encoder_output"] = self.prepare_prev_latents(prev_video, prev_frame_length)

         # Run inference loop
         if total_steps is None:
             total_steps = self.model.scheduler.infer_steps
...

lightx2v/models/runners/wan/wan_causvid_runner.py  +0 -1
...
@@ -29,7 +29,6 @@ class WanCausVidRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
...

lightx2v/models/runners/wan/wan_distill_runner.py  +0 -3
...
@@ -21,7 +21,6 @@ class WanDistillRunner(WanRunner):
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         lora_wrapper = WanLoraWrapper(model)
         for lora_config in self.config.lora_configs:
...
@@ -91,7 +90,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "high_noise_model"),
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         high_lora_wrapper = WanLoraWrapper(high_noise_model)
         for lora_config in self.config.lora_configs:
...
@@ -106,7 +104,6 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             os.path.join(self.config.model_path, "distill_models", "high_noise_model"),
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         if use_low_lora:
...

lightx2v/models/runners/wan/wan_runner.py  +0 -9
...
@@ -34,18 +34,12 @@ from lightx2v.utils.utils import best_output_size, cache_video
 class WanRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)
-        device_mesh = self.config.get("device_mesh")
-        if device_mesh is not None:
-            self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p")
-        else:
-            self.seq_p_group = None

     def load_transformer(self):
         model = WanModel(
             self.config.model_path,
             self.config,
             self.init_device,
-            self.seq_p_group,
         )
         if self.config.get("lora_configs") and self.config.lora_configs:
             assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
...
@@ -83,7 +77,6 @@ class WanRunner(DefaultRunner):
             clip_quantized=clip_quantized,
             clip_quantized_ckpt=clip_quantized_ckpt,
             quant_scheme=clip_quant_scheme,
-            seq_p_group=self.seq_p_group,
             cpu_offload=self.config.get("clip_cpu_offload", self.config.get("cpu_offload", False)),
             use_31_block=self.config.get("use_31_block", True),
         )
...
@@ -127,7 +120,6 @@ class WanRunner(DefaultRunner):
             t5_quantized=t5_quantized,
             t5_quantized_ckpt=t5_quantized_ckpt,
             quant_scheme=t5_quant_scheme,
-            seq_p_group=self.seq_p_group,
         )
         text_encoders = [text_encoder]
         return text_encoders
...
@@ -145,7 +137,6 @@ class WanRunner(DefaultRunner):
             "device": vae_device,
             "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
             "use_tiling": self.config.get("use_tiling_vae", False),
-            "seq_p_group": self.seq_p_group,
             "cpu_offload": vae_offload,
         }
         if self.config.task != "i2v":
...
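
Review note: the clip_cpu_offload lookup kept above shows the config fallback idiom used throughout this runner: a module-specific key overrides the global cpu_offload flag, which itself defaults to False. A minimal illustration with a hypothetical config dict:

config = {"cpu_offload": True}  # global offload on, no CLIP-specific key set

# Per-module key wins if present; otherwise fall back to the global flag, then False.
clip_offload = config.get("clip_cpu_offload", config.get("cpu_offload", False))
print(clip_offload)  # True, inherited from the global flag

config["clip_cpu_offload"] = False  # explicitly keep CLIP on GPU while offloading the rest
print(config.get("clip_cpu_offload", config.get("cpu_offload", False)))  # False
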

lightx2v/models/schedulers/wan/scheduler.py  +3 -6
...
@@ -60,12 +60,9 @@ class WanScheduler(BaseScheduler):
             device=self.device,
             generator=self.generator,
         )
-        if self.config["model_cls"] == "wan2.2":
-            if self.config["task"] == "t2v":
-                self.mask = masks_like(self.latents, zero=False)
-            elif self.config["task"] == "i2v":
-                self.mask = masks_like(self.latents, zero=True)
-            self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents
+        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
+            self.mask = masks_like(self.latents, zero=True)
+            self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents

     def set_timesteps(
         self,
...
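
Review note: the retained i2v branch blends the VAE-encoded conditioning into the initial noise: where the mask is 0 the latent is taken from vae_encoder_out (the conditioning frames), and where it is 1 the sampled noise survives. A toy 1-D demonstration, with a hand-built mask standing in for the project's masks_like(..., zero=True):

import torch

latents = torch.randn(6)                 # stand-in for the sampled noise latents
vae_encoder_out = torch.full((6,), 9.0)  # stand-in for encoded conditioning frames

# Hypothetical masks_like(..., zero=True) output: zeros on the conditioning
# positions (here the first two "frames"), ones everywhere else.
mask = torch.tensor([0.0, 0.0, 1.0, 1.0, 1.0, 1.0])

latents = (1.0 - mask) * vae_encoder_out + mask * latents
print(latents[:2])  # tensor([9., 9.]) -- conditioning kept verbatim where mask == 0
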

lightx2v/models/video_encoders/hf/wan/vae.py  +2 -3
...
@@ -759,7 +759,7 @@ class WanVAE_(nn.Module):
         self._enc_feat_map = [None] * self._enc_conv_num


-def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None, cpu_offload=False, **kwargs):
+def _video_vae(pretrained_path=None, z_dim=None, device="cpu", cpu_offload=False, **kwargs):
     """
     Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
     """
...
@@ -795,7 +795,6 @@ class WanVAE:
         device="cuda",
         parallel=False,
         use_tiling=False,
-        seq_p_group=None,
         cpu_offload=False,
     ):
         self.dtype = dtype
...
@@ -845,7 +844,7 @@ class WanVAE:
         self.scale = [self.mean, self.inv_std]

         # init model
-        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, seq_p_group=seq_p_group, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
+        self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)

     def current_device(self):
         return next(self.model.parameters()).device
...

lightx2v/utils/set_config.py  +10 -3
...
@@ -69,9 +69,6 @@ def set_config(args):
 def set_parallel_config(config):
     if config.parallel:
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
         cfg_p_size = config.parallel.get("cfg_p_size", 1)
         seq_p_size = config.parallel.get("seq_p_size", 1)
         assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
...
@@ -82,3 +79,13 @@ def set_parallel_config(config):
     if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
         config["cfg_parallel"] = True
+
+
+def print_config(config):
+    config_to_print = config.copy()
+    config_to_print.pop("device_mesh", None)
+    if config.parallel:
+        if dist.get_rank() == 0:
+            logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
+    else:
+        logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
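
Review note: the new print_config pops device_mesh before serializing, presumably because a torch DeviceMesh object is not JSON-serializable and json.dumps would raise TypeError on it. A quick self-contained check, with object() standing in for a DeviceMesh:

import json

config = {"seq_parallel": True, "device_mesh": object()}  # object() stands in for a DeviceMesh

try:
    json.dumps(config)
except TypeError as e:
    print("unserializable:", e)

config_to_print = dict(config)            # shallow copy, like config.copy()
config_to_print.pop("device_mesh", None)  # drop the offending entry
print(json.dumps(config_to_print, ensure_ascii=False, indent=4))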