xuwx1 / LightX2V · Commits · 6943aa52

Commit 6943aa52, authored Jul 30, 2025 by wangshankun
wan2.2 audio driven
parent fa7aedbe
Showing 8 changed files with 167 additions and 13 deletions:

configs/wan22/wan_i2v_audio.json                                  +34  -0
lightx2v/infer.py                                                  +6  -2
lightx2v/models/networks/wan/audio_model.py                       +11  -1
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py   +11  -3
lightx2v/models/runners/wan/wan_audio_runner.py                   +47  -5
lightx2v/models/runners/wan/wan_runner.py                          +4  -0
lightx2v/models/video_encoders/hf/wan/vae.py                       +9  -2
scripts/wan22/run_wan22_moe_i2v_audio.sh                          +45  -0
configs/wan22/wan_i2v_audio.json (new file, mode 100755)

{
    "infer_steps": 6,
    "target_fps": 16,
    "video_duration": 16,
    "audio_sr": 16000,
    "text_len": 512,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": [1.0, 1.0],
    "sample_shift": 5.0,
    "enable_cfg": false,
    "cpu_offload": true,
    "offload_granularity": "model",
    "boundary": 0.900,
    "use_image_encoder": false,
    "use_31_block": false,
    "lora_configs": [
        {
            "name": "high_noise_model",
            "path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
            "strength": 1.0
        },
        {
            "name": "low_noise_model",
            "path": "/mnt/Text2Video/wuzhuguanyu/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
            "strength": 1.0
        }
    ]
}
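Two fields here are worth calling out: "sample_guide_scale" is a two-element list (one guidance scale per expert) and "boundary": 0.900 is the fraction of the timestep range at which inference hands over between the high-noise and low-noise models. A minimal sketch of how those fields relate, assuming the usual 1000-step timestep range (the variable names are illustrative, not the library's):

import json

with open("configs/wan22/wan_i2v_audio.json") as f:
    cfg = json.load(f)

high_scale, low_scale = cfg["sample_guide_scale"]   # one CFG scale per expert
num_train_timesteps = 1000                          # assumption, not in the config
boundary_timestep = cfg["boundary"] * num_train_timesteps

# under this reading, timesteps >= 900 would go to high_noise_model,
# the rest to low_noise_model
print(high_scale, low_scale, boundary_timestep)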
lightx2v/infer.py

@@ -13,7 +13,7 @@ from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
 from lightx2v.models.runners.wan.wan_runner import WanRunner, Wan22MoeRunner
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
-from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
+from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner, Wan22MoeAudioRunner
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner

@@ -42,7 +42,11 @@ def init_runner(config):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe"], default="wan2.1")
+    parser.add_argument(
+        "--model_cls", type=str, required=True,
+        choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio"],
+        default="wan2.1",
+    )
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
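The new "wan2.2_moe_audio" choice only resolves to a runner because the corresponding class registers itself under that exact name (see wan_audio_runner.py below). A hedged sketch of the registry pattern that RUNNER_REGISTER presumably implements; the real one lives in lightx2v.utils.registry_factory and may differ in detail:

RUNNERS = {}

def runner_register(name):
    # decorator factory: map a model_cls string to a runner class
    def decorator(cls):
        RUNNERS[name] = cls
        return cls
    return decorator

@runner_register("wan2.2_moe_audio")
class DummyRunner:
    pass

# init_runner(config) can then dispatch on config.model_cls:
runner_cls = RUNNERS["wan2.2_moe_audio"]
assert runner_cls is DummyRunner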
lightx2v/models/networks/wan/audio_model.py

@@ -61,8 +61,18 @@ class WanAudioModel(WanModel):
         if self.scheduler.cnt >= self.scheduler.num_steps:
             self.scheduler.cnt = 0
-        self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+        self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()


+class Wan22MoeAudioModel(WanAudioModel):
+    def _load_ckpt(self, use_bf16, skip_bf16):
+        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
+        weight_dict = {}
+        for file_path in safetensors_files:
+            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            weight_dict.update(file_weights)
+        return weight_dict
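The one-line change moves the guidance-scale lookup from self.config to self.scheduler, which matters once two experts share one config: the config now carries a list of scales while the scheduler holds the scalar for the currently active model. The formula itself is plain classifier-free guidance; a small self-contained check (the function name is mine, not the library's):

import torch

def cfg_combine(uncond, cond, guide_scale):
    # classifier-free guidance: extrapolate from the unconditional
    # prediction toward the conditional one
    return uncond + guide_scale * (cond - uncond)

u, c = torch.zeros(3), torch.ones(3)
# with guide_scale == 1.0, as in wan_i2v_audio.json, CFG reduces to the
# conditional prediction, consistent with "enable_cfg": false there
assert torch.equal(cfg_combine(u, c, 1.0), c)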
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py

@@ -9,7 +9,7 @@ class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         d = config["dim"] // config["num_heads"]
+        self.config = config
         self.task = config["task"]
         self.freqs = torch.cat(
             [

@@ -22,6 +22,7 @@ class WanAudioPreInfer(WanPreInfer):
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.text_len = config["text_len"]
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)

     def infer(self, weights, inputs, positive):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)

@@ -93,13 +94,20 @@ class WanAudioPreInfer(WanPreInfer):
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)
-        if self.task == "i2v":
+        if self.task == "i2v" and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
             context_clip = weights.proj_1.apply(context_clip)
             context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
             context_clip = weights.proj_3.apply(context_clip)
             context_clip = weights.proj_4.apply(context_clip)
+            if self.clean_cuda_cache:
+                del clip_fea
+                torch.cuda.empty_cache()
             context = torch.concat([context_clip, context], dim=0)
+        if self.clean_cuda_cache:
+            if self.config.get("use_image_encoder", True):
+                del context_clip
+            torch.cuda.empty_cache()
         return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
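The new clean_cuda_cache branches follow a common PyTorch memory pattern: drop the last Python reference to a large tensor, then ask the caching allocator to hand freed blocks back to the driver. A minimal sketch of that pattern, assuming a CUDA device is available:

import torch

x = torch.randn(4096, 4096, device="cuda")  # stand-in for clip_fea / context_clip
y = x * 2                                   # downstream use keeps only y
del x                                       # last reference gone; block returns to the allocator cache
torch.cuda.empty_cache()                    # cache -> driver; lowers reserved (not just allocated) memory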
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -11,11 +11,12 @@ from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
-from lightx2v.models.networks.wan.audio_model import WanAudioModel
+from lightx2v.models.networks.wan.audio_model import WanAudioModel, Wan22MoeAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from .wan_runner import MultiModelStruct
 from loguru import logger
 from einops import rearrange

@@ -262,7 +263,7 @@ class VideoGenerator:
         if prev_video is None:
             return None
-        device = self.model.device
+        device = torch.device("cuda")
         dtype = torch.bfloat16
         vae_dtype = torch.float

@@ -315,7 +316,7 @@ class VideoGenerator:
         self.model.scheduler.reset()
         # Prepare previous latents - ALWAYS needed, even for first segment
-        device = self.model.device
+        device = torch.device("cuda")
         dtype = torch.bfloat16
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w

@@ -423,7 +424,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
         # Audio encoder
-        device = self.model.device
+        device = torch.device("cuda")
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
         self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)

@@ -655,7 +656,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16) if self.config.get("use_image_encoder", True) else None
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")

@@ -684,3 +685,44 @@ class WanAudioRunner(WanRunner):  # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret
+
+
+@RUNNER_REGISTER("wan2.2_moe_audio")
+class Wan22MoeAudioRunner(WanAudioRunner):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def load_transformer(self):
+        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
+        high_noise_model = Wan22MoeAudioModel(
+            os.path.join(self.config.model_path, "high_noise_model"),
+            self.config,
+            self.init_device,
+        )
+        low_noise_model = Wan22MoeAudioModel(
+            os.path.join(self.config.model_path, "low_noise_model"),
+            self.config,
+            self.init_device,
+        )
+        if self.config.get("lora_configs") and self.config.lora_configs:
+            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
+                if lora_config.name == "high_noise_model":
+                    lora_wrapper = WanLoraWrapper(high_noise_model)
+                    lora_name = lora_wrapper.load_lora(lora_path)
+                    lora_wrapper.apply_lora(lora_name, strength)
+                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
+                if lora_config.name == "low_noise_model":
+                    lora_wrapper = WanLoraWrapper(low_noise_model)
+                    lora_name = lora_wrapper.load_lora(lora_path)
+                    lora_wrapper.apply_lora(lora_name, strength)
+                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
+        # XXX: trick
+        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
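Wan22MoeAudioRunner.load_transformer returns a MultiModelStruct over the two experts, with self.config.boundary deciding the handover. A hedged sketch of the switching rule that boundary-based expert selection implies, consistent with the "high -> low" ordering in the comment above (the function name and the 1000-step range are assumptions):

def pick_expert(t, boundary=0.9, num_train_timesteps=1000):
    # the high-noise expert handles early, noisy timesteps;
    # the low-noise expert takes over once t drops below the boundary
    if t >= boundary * num_train_timesteps:
        return "high_noise_model"
    return "low_noise_model"

assert pick_expert(950) == "high_noise_model"
assert pick_expert(120) == "low_noise_model"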
lightx2v/models/runners/wan/wan_runner.py

@@ -305,6 +305,10 @@ class MultiModelStruct:
         self.cur_model_index = -1
         logger.info(f"boundary: {self.boundary}, boundary_timestep: {self.boundary_timestep}")

+    @property
+    def device(self):
+        return self.model[self.cur_model_index].device
+
     def set_scheduler(self, shared_scheduler):
         self.scheduler = shared_scheduler
         for model in self.model:
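The new device property lets callers write self.model.device whether self.model is a single network or this multi-model struct. A minimal sketch of the delegation, with a stand-in model class (names here are illustrative):

class FakeModel:
    def __init__(self, device):
        self.device = device

class Struct:
    def __init__(self, models):
        self.model = models
        self.cur_model_index = -1  # -1 -> last model until a step selects one

    @property
    def device(self):
        # delegate to whichever sub-model is currently active
        return self.model[self.cur_model_index].device

s = Struct([FakeModel("cuda:0"), FakeModel("cpu")])
assert s.device == "cpu"  # index -1 -> the last (low-noise) model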
lightx2v/models/video_encoders/hf/wan/vae.py

@@ -852,13 +852,20 @@ class WanVAE:
             .to(device)
         )

+    def current_device(self):
+        return next(self.model.parameters()).device
+
     def to_cpu(self):
+        self.model.encoder = self.model.encoder.to("cpu")
+        self.model.decoder = self.model.decoder.to("cpu")
         self.model = self.model.to("cpu")
         self.mean = self.mean.cpu()
         self.inv_std = self.inv_std.cpu()
         self.scale = [self.mean, self.inv_std]

     def to_cuda(self):
+        self.model.encoder = self.model.encoder.to("cuda")
+        self.model.decoder = self.model.decoder.to("cuda")
         self.model = self.model.to("cuda")
         self.mean = self.mean.cuda()
         self.inv_std = self.inv_std.cuda()

@@ -872,9 +879,9 @@ class WanVAE:
             self.to_cuda()
         if self.use_tiling:
-            out = [self.model.tiled_encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+            out = [self.model.tiled_encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
         else:
-            out = [self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos]
+            out = [self.model.encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
         if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cpu()
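With cpu_offload enabled the VAE weights migrate between CPU and GPU around every call, so a tensor prepared earlier can end up on a different device than the weights; current_device() plus the .to(...) added to both encode paths guards against that. A small self-contained illustration of the helper, using torch.nn.Linear as a stand-in for the VAE:

import torch

def current_device(module: torch.nn.Module) -> torch.device:
    # same trick as WanVAE.current_device: read the device off any parameter
    return next(module.parameters()).device

vae = torch.nn.Linear(8, 8)             # stand-in; lives on CPU here
x = torch.randn(1, 8)                   # input prepared elsewhere
out = vae(x.to(current_device(vae)))    # safe wherever `vae` currently lives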
scripts/wan22/run_wan22_moe_i2v_audio.sh (new file, mode 100755)

#!/bin/bash
# set paths first
lightx2v_path="/mnt/Text2Video2/wangshankun/lightx2v"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.2-I2V-A14B"

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# The negative prompt below is kept in Chinese, as the model was tuned on it;
# roughly: "garish colors, overexposed, static, blurry details, subtitles,
# style, artwork, painting, frame, still, overall gray, worst quality, low
# quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly
# drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused
# fingers, motionless frame, cluttered background, three legs, crowded
# background, walking backwards".
#-m debugpy --wait-for-client --listen 0.0.0.0:15684 \
python \
    -m lightx2v.infer \
    --model_cls wan2.2_moe_audio \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan22/wan_i2v_audio.json \
    --prompt "The video features an old lady saying something while knitting a sweater." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_audio.mp4