Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
9e3680b7
Commit
9e3680b7
authored
Aug 14, 2025
by
helloyongyang
Browse files
fix ci
parent
7367d6c8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
22 deletions
+29
-22
lightx2v/infer.py
lightx2v/infer.py
+15
-2
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
...2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
+5
-6
lightx2v/models/networks/wan/infer/utils.py
lightx2v/models/networks/wan/infer/utils.py
+2
-8
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+7
-5
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
+0
-1
No files found.
lightx2v/infer.py
View file @
9e3680b7
...
@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
...
@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
from
lightx2v.models.runners.cogvideox.cogvidex_runner
import
CogvideoxRunner
# noqa: F401
from
lightx2v.models.runners.cogvideox.cogvidex_runner
import
CogvideoxRunner
# noqa: F401
from
lightx2v.models.runners.graph_runner
import
GraphRunner
from
lightx2v.models.runners.graph_runner
import
GraphRunner
from
lightx2v.models.runners.hunyuan.hunyuan_runner
import
HunyuanRunner
# noqa: F401
from
lightx2v.models.runners.hunyuan.hunyuan_runner
import
HunyuanRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_audio_runner
import
Wan22MoeAudioRunner
,
WanAudioRunner
,
Wan22AudioRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_audio_runner
import
Wan22AudioRunner
,
Wan22MoeAudioRunner
,
WanAudioRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_causvid_runner
import
WanCausVidRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_causvid_runner
import
WanCausVidRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_distill_runner
import
WanDistillRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_distill_runner
import
WanDistillRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_runner
import
Wan22MoeRunner
,
WanRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_runner
import
Wan22MoeRunner
,
WanRunner
# noqa: F401
...
@@ -39,7 +39,20 @@ def main():
...
@@ -39,7 +39,20 @@ def main():
"--model_cls"
,
"--model_cls"
,
type
=
str
,
type
=
str
,
required
=
True
,
required
=
True
,
choices
=
[
"wan2.1"
,
"hunyuan"
,
"wan2.1_distill"
,
"wan2.1_causvid"
,
"wan2.1_skyreels_v2_df"
,
"cogvideox"
,
"wan2.1_audio"
,
"wan2.2_moe"
,
"wan2.2_moe_audio"
,
"wan2.2_audio"
,
"wan2.2"
,
"wan2.2_moe_distill"
],
choices
=
[
"wan2.1"
,
"hunyuan"
,
"wan2.1_distill"
,
"wan2.1_causvid"
,
"wan2.1_skyreels_v2_df"
,
"cogvideox"
,
"wan2.1_audio"
,
"wan2.2_moe"
,
"wan2.2"
,
"wan2.2_moe_audio"
,
"wan2.2_audio"
,
"wan2.2_moe_distill"
,
],
default
=
"wan2.1"
,
default
=
"wan2.1"
,
)
)
...
...
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
View file @
9e3680b7
...
@@ -4,8 +4,8 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
...
@@ -4,8 +4,8 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
..module_io
import
WanPreInferModuleOutput
from
..module_io
import
WanPreInferModuleOutput
from
..utils
import
rope_params
,
sinusoidal_embedding_1d
,
masks_like
from
..utils
import
masks_like
,
rope_params
,
sinusoidal_embedding_1d
from
loguru
import
logger
class
WanAudioPreInfer
(
WanPreInfer
):
class
WanAudioPreInfer
(
WanPreInfer
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -29,12 +29,11 @@ class WanAudioPreInfer(WanPreInfer):
...
@@ -29,12 +29,11 @@ class WanAudioPreInfer(WanPreInfer):
self
.
sensitive_layer_dtype
=
GET_SENSITIVE_DTYPE
()
self
.
sensitive_layer_dtype
=
GET_SENSITIVE_DTYPE
()
def
infer
(
self
,
weights
,
inputs
,
positive
):
def
infer
(
self
,
weights
,
inputs
,
positive
):
prev_latents
=
inputs
[
"previmg_encoder_output"
][
"prev_latents"
]
prev_latents
=
inputs
[
"previmg_encoder_output"
][
"prev_latents"
]
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
hidden_states
=
self
.
scheduler
.
latents
hidden_states
=
self
.
scheduler
.
latents
mask1
,
mask2
=
masks_like
([
hidden_states
],
zero
=
True
,
prev_length
=
hidden_states
.
shape
[
1
])
mask1
,
mask2
=
masks_like
([
hidden_states
],
zero
=
True
,
prev_length
=
hidden_states
.
shape
[
1
])
hidden_states
=
(
1.
-
mask2
[
0
])
*
prev_latents
+
mask2
[
0
]
*
hidden_states
hidden_states
=
(
1.
0
-
mask2
[
0
])
*
prev_latents
+
mask2
[
0
]
*
hidden_states
else
:
else
:
prev_latents
=
prev_latents
.
unsqueeze
(
0
)
prev_latents
=
prev_latents
.
unsqueeze
(
0
)
prev_mask
=
inputs
[
"previmg_encoder_output"
][
"prev_mask"
]
prev_mask
=
inputs
[
"previmg_encoder_output"
][
"prev_mask"
]
...
@@ -53,7 +52,7 @@ class WanAudioPreInfer(WanPreInfer):
...
@@ -53,7 +52,7 @@ class WanAudioPreInfer(WanPreInfer):
"timestep"
:
t
,
"timestep"
:
t
,
}
}
audio_dit_blocks
.
append
(
inputs
[
"audio_adapter_pipe"
](
**
audio_model_input
))
audio_dit_blocks
.
append
(
inputs
[
"audio_adapter_pipe"
](
**
audio_model_input
))
audio_dit_blocks
=
None
##Debug Drop Audio
audio_dit_blocks
=
None
##Debug Drop Audio
if
positive
:
if
positive
:
context
=
inputs
[
"text_encoder_output"
][
"context"
]
context
=
inputs
[
"text_encoder_output"
][
"context"
]
...
@@ -66,7 +65,7 @@ class WanAudioPreInfer(WanPreInfer):
...
@@ -66,7 +65,7 @@ class WanAudioPreInfer(WanPreInfer):
batch_size
=
len
(
x
)
batch_size
=
len
(
x
)
num_channels
,
_
,
height
,
width
=
x
[
0
].
shape
num_channels
,
_
,
height
,
width
=
x
[
0
].
shape
_
,
ref_num_channels
,
ref_num_frames
,
_
,
_
=
ref_image_encoder
.
shape
_
,
ref_num_channels
,
ref_num_frames
,
_
,
_
=
ref_image_encoder
.
shape
if
ref_num_channels
!=
num_channels
:
if
ref_num_channels
!=
num_channels
:
zero_padding
=
torch
.
zeros
(
zero_padding
=
torch
.
zeros
(
(
batch_size
,
num_channels
-
ref_num_channels
,
ref_num_frames
,
height
,
width
),
(
batch_size
,
num_channels
-
ref_num_channels
,
ref_num_frames
,
height
,
width
),
...
...
lightx2v/models/networks/wan/infer/utils.py
View file @
9e3680b7
...
@@ -15,15 +15,9 @@ def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
...
@@ -15,15 +15,9 @@ def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
if
zero
:
if
zero
:
if
generator
is
not
None
:
if
generator
is
not
None
:
for
u
,
v
in
zip
(
out1
,
out2
):
for
u
,
v
in
zip
(
out1
,
out2
):
random_num
=
torch
.
rand
(
random_num
=
torch
.
rand
(
1
,
generator
=
generator
,
device
=
generator
.
device
).
item
()
1
,
generator
=
generator
,
device
=
generator
.
device
).
item
()
if
random_num
<
p
:
if
random_num
<
p
:
u
[:,
:
prev_length
]
=
torch
.
normal
(
u
[:,
:
prev_length
]
=
torch
.
normal
(
mean
=-
3.5
,
std
=
0.5
,
size
=
(
1
,),
device
=
u
.
device
,
generator
=
generator
).
expand_as
(
u
[:,
:
prev_length
]).
exp
()
mean
=-
3.5
,
std
=
0.5
,
size
=
(
1
,),
device
=
u
.
device
,
generator
=
generator
).
expand_as
(
u
[:,
:
prev_length
]).
exp
()
v
[:,
:
prev_length
]
=
torch
.
zeros_like
(
v
[:,
:
prev_length
])
v
[:,
:
prev_length
]
=
torch
.
zeros_like
(
v
[:,
:
prev_length
])
else
:
else
:
u
[:,
:
prev_length
]
=
u
[:,
:
prev_length
]
u
[:,
:
prev_length
]
=
u
[:,
:
prev_length
]
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
9e3680b7
...
@@ -21,11 +21,12 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
...
@@ -21,11 +21,12 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.runners.wan.wan_runner
import
MultiModelStruct
,
WanRunner
from
lightx2v.models.runners.wan.wan_runner
import
MultiModelStruct
,
WanRunner
from
lightx2v.models.schedulers.wan.audio.scheduler
import
ConsistencyModelScheduler
from
lightx2v.models.schedulers.wan.audio.scheduler
import
ConsistencyModelScheduler
from
lightx2v.models.video_encoders.hf.wan.vae_2_2
import
Wan2_2_VAE
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext
,
ProfilingContext4Debug
from
lightx2v.utils.profiler
import
ProfilingContext
,
ProfilingContext4Debug
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.utils
import
save_to_video
,
vae_to_comfyui_image
,
find_torch_model_path
from
lightx2v.utils.utils
import
find_torch_model_path
,
save_to_video
,
vae_to_comfyui_image
from
lightx2v.models.video_encoders.hf.wan.vae_2_2
import
Wan2_2_VAE
@
contextmanager
@
contextmanager
def
memory_efficient_inference
():
def
memory_efficient_inference
():
...
@@ -322,7 +323,7 @@ class VideoGenerator:
...
@@ -322,7 +323,7 @@ class VideoGenerator:
if
segment_idx
==
0
:
if
segment_idx
==
0
:
# First segment - create zero frames
# First segment - create zero frames
prev_frames
=
torch
.
zeros
((
1
,
3
,
max_num_frames
,
tgt_h
,
tgt_w
),
device
=
device
)
prev_frames
=
torch
.
zeros
((
1
,
3
,
max_num_frames
,
tgt_h
,
tgt_w
),
device
=
device
)
if
self
.
config
.
model_cls
==
'
wan2.2_audio
'
:
if
self
.
config
.
model_cls
==
"
wan2.2_audio
"
:
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
).
to
(
dtype
)
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
).
to
(
dtype
)
else
:
else
:
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
)[
0
].
to
(
dtype
)
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
)[
0
].
to
(
dtype
)
...
@@ -337,7 +338,7 @@ class VideoGenerator:
...
@@ -337,7 +338,7 @@ class VideoGenerator:
else
:
else
:
# Fallback to zeros if prepare_prev_latents fails
# Fallback to zeros if prepare_prev_latents fails
prev_frames
=
torch
.
zeros
((
1
,
3
,
max_num_frames
,
tgt_h
,
tgt_w
),
device
=
device
)
prev_frames
=
torch
.
zeros
((
1
,
3
,
max_num_frames
,
tgt_h
,
tgt_w
),
device
=
device
)
if
self
.
config
.
model_cls
==
'
wan2.2_audio
'
:
if
self
.
config
.
model_cls
==
"
wan2.2_audio
"
:
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
).
to
(
dtype
)
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
).
to
(
dtype
)
else
:
else
:
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
)[
0
].
to
(
dtype
)
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
vae_dtype
),
self
.
config
)[
0
].
to
(
dtype
)
...
@@ -695,7 +696,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -695,7 +696,7 @@ class WanAudioRunner(WanRunner): # type:ignore
num_channels_latents
=
16
num_channels_latents
=
16
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
num_channels_latents
=
self
.
config
.
num_channels_latents
num_channels_latents
=
self
.
config
.
num_channels_latents
if
self
.
config
.
task
==
"i2v"
:
if
self
.
config
.
task
==
"i2v"
:
self
.
config
.
target_shape
=
(
self
.
config
.
target_shape
=
(
num_channels_latents
,
num_channels_latents
,
...
@@ -813,6 +814,7 @@ class Wan22AudioRunner(WanAudioRunner):
...
@@ -813,6 +814,7 @@ class Wan22AudioRunner(WanAudioRunner):
vae_decoder
=
self
.
load_vae_decoder
()
vae_decoder
=
self
.
load_vae_decoder
()
return
vae_encoder
,
vae_decoder
return
vae_encoder
,
vae_decoder
@
RUNNER_REGISTER
(
"wan2.2_moe_audio"
)
@
RUNNER_REGISTER
(
"wan2.2_moe_audio"
)
class
Wan22MoeAudioRunner
(
WanAudioRunner
):
class
Wan22MoeAudioRunner
(
WanAudioRunner
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
...
lightx2v/models/video_encoders/hf/wan/vae_2_2.py
View file @
9e3680b7
...
@@ -7,7 +7,6 @@ import torch.nn.functional as F
...
@@ -7,7 +7,6 @@ import torch.nn.functional as F
from
einops
import
rearrange
from
einops
import
rearrange
from
lightx2v.utils.utils
import
load_weights
from
lightx2v.utils.utils
import
load_weights
from
loguru
import
logger
__all__
=
[
__all__
=
[
"Wan2_2_VAE"
,
"Wan2_2_VAE"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment