LightX2V · Commit d8454a2b

Refactor runners

Authored Aug 25, 2025 by helloyongyang
Parent: 2054eca3

Showing 18 changed files with 544 additions and 774 deletions (+544 / -774)
configs/audio_driven/wan_i2v_audio_quant.json                        +3    -1
lightx2v/models/input_encoders/hf/q_linear.py                        +57   -0
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py        +21   -184
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py        +29   -0
lightx2v/models/networks/wan/audio_model.py                          +5    -0
lightx2v/models/networks/wan/infer/audio/post_infer.py               +2    -22
lightx2v/models/networks/wan/infer/audio/pre_infer.py                +15   -24
lightx2v/models/networks/wan/infer/audio/transformer_infer.py        +80   -6
lightx2v/models/networks/wan/infer/module_io.py                      +5    -5
lightx2v/models/networks/wan/infer/vace/transformer_infer.py         +2    -2
lightx2v/models/runners/base_runner.py                               +32   -50
lightx2v/models/runners/default_runner.py                            +37   -25
lightx2v/models/runners/wan/wan_audio_runner.py                      +194  -382
lightx2v/models/runners/wan/wan_runner.py                            +8    -12
lightx2v/models/schedulers/wan/audio/scheduler.py                    +16   -59
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py      +1    -0
lightx2v/models/schedulers/wan/scheduler.py                          +1    -2
tools/convert/quant_adapter.py                                       +36   -0
configs/audio_driven/wan_i2v_audio_quant.json

@@ -18,5 +18,7 @@
        "dit_quantized_ckpt": "/path/to/Wan2.1-R2V721-Audio-14B-720P/fp8",
        "mm_config": {
            "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"
        }
    },
    "adapter_quantized": true,
    "adapter_quant_scheme": "fp8"
}
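The two new keys pair the fp8 DiT checkpoint with an fp8 audio adapter. A minimal sketch (not code from this commit; the loader and variable names are assumptions) of how they can be routed into the AudioAdapter constructor extended later in this diff:

    # Illustrative only: read the new config keys and forward them to AudioAdapter,
    # whose quantized/quant_scheme parameters are added in this commit.
    import json

    with open("configs/audio_driven/wan_i2v_audio_quant.json") as f:
        config = json.load(f)

    adapter_kwargs = {
        "quantized": config.get("adapter_quantized", False),
        "quant_scheme": config.get("adapter_quant_scheme", None),
    }
    # audio_adapter = AudioAdapter(..., **adapter_kwargs)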
lightx2v/models/input_encoders/hf/q_linear.py

...
@@ -6,6 +6,11 @@ try:
except ModuleNotFoundError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ModuleNotFoundError:
...
@@ -117,6 +122,58 @@ class VllmQuantLinearFp8(nn.Module):
        return self


class SglQuantLinearFp8(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32))
        if bias:
            self.register_buffer("bias", torch.empty(out_features, dtype=dtype))
        else:
            self.register_buffer("bias", None)

    def act_quant_func(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def forward(self, input_tensor):
        input_tensor = input_tensor.squeeze(0)
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight.t(),
            input_tensor_scale,
            self.weight_scale,
            dtype,
            bias=self.bias,
        )
        return output_tensor.unsqueeze(0)

    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        def maybe_cast(t):
            if t is not None and t.device != fn(t).device:
                return fn(t)
            return t

        self.weight = maybe_cast(self.weight)
        self.weight_scale = maybe_cast(self.weight_scale)
        self.bias = maybe_cast(self.bias)
        return self


class TorchaoQuantLinearInt8(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.bfloat16):
        super().__init__()
...
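A minimal usage sketch for the new layer (illustrative, not code from the repo: it assumes a CUDA build of sgl_kernel and fills the fp8 buffers with dummy values instead of a converted checkpoint):

    import torch
    from lightx2v.models.input_encoders.hf.q_linear import SglQuantLinearFp8

    layer = SglQuantLinearFp8(in_features=2048, out_features=2048).cuda()
    # In practice weight / weight_scale / bias come from the converted fp8 state dict;
    # dummy values here just exercise the forward path.
    layer.weight.copy_(torch.randn(2048, 2048, device="cuda").to(torch.float8_e4m3fn))
    layer.weight_scale.fill_(1.0)
    layer.bias.zero_()

    x = torch.randn(1, 16, 2048, device="cuda", dtype=torch.bfloat16)  # [1, tokens, in_features]
    y = layer(x)  # [1, tokens, out_features]: per-token fp8 activation quant + fp8 GEMM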
lightx2v/models/networks/wan/audio_adapter.py → lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py

...
@@ -13,9 +13,8 @@ import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from loguru import logger
from transformers import AutoModel
from lightx2v.utils.envs import *
from lightx2v.models.input_encoders.hf.q_linear import SglQuantLinearFp8


def load_safetensors(in_path: str):
...
@@ -84,8 +83,6 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
    for buffer in model.buffers():
        dist.broadcast(buffer.data, src=0)
    return model.to(dtype=GET_DTYPE())


def linear_interpolation(features, output_len: int):
    features = features.transpose(1, 2)
...
@@ -120,7 +117,7 @@ def get_q_lens_audio_range(
class PerceiverAttentionCA(nn.Module):
-    def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
+    def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False, quantized=False, quant_scheme=None):
        super().__init__()
        self.dim_head = dim_head
        self.heads = heads
...
@@ -129,9 +126,17 @@ class PerceiverAttentionCA(nn.Module):
        self.norm_kv = nn.LayerNorm(kv_dim)
        self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
-        self.to_q = nn.Linear(inner_dim, inner_dim)
-        self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
-        self.to_out = nn.Linear(inner_dim, inner_dim)
+        if quantized:
+            if quant_scheme == "fp8":
+                self.to_q = SglQuantLinearFp8(inner_dim, inner_dim)
+                self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
+                self.to_out = SglQuantLinearFp8(inner_dim, inner_dim)
+            else:
+                raise ValueError(f"Unsupported quant_scheme: {quant_scheme}")
+        else:
+            self.to_q = nn.Linear(inner_dim, inner_dim)
+            self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
+            self.to_out = nn.Linear(inner_dim, inner_dim)
        if adaLN:
            self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
        else:
...
@@ -151,7 +156,7 @@ class PerceiverAttentionCA(nn.Module):
            shift = shift.transpose(0, 1)
            gate = gate.transpose(0, 1)
        latents = norm_q * (1 + scale) + shift
-        q = self.to_q(latents.to(GET_DTYPE()))
+        q = self.to_q(latents)
        k, v = self.to_kv(x).chunk(2, dim=-1)
        q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
        k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
...
@@ -258,6 +263,8 @@ class AudioAdapter(nn.Module):
        mlp_dims: tuple = (1024, 1024, 32 * 768),
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
+        quantized: bool = False,
+        quant_scheme: str = None,
    ):
        super().__init__()
        self.audio_proj = AudioProjection(
...
@@ -280,6 +287,8 @@ class AudioAdapter(nn.Module):
                heads=num_attention_heads,
                kv_dim=mlp_dims[-1] // num_tokens,
                adaLN=time_freq_dim > 0,
+                quantized=quantized,
+                quant_scheme=quant_scheme,
            )
            for _ in range(ca_num)
        ]
...
@@ -298,181 +307,9 @@ class AudioAdapter(nn.Module):
        audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
        return audio_feature

    def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0, seq_p_group=None):
        def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight, seq_p_group):
            """thw specify the latent_frame, latent_height, latent_width after
            hidden_states is patchified.
            latent_frame does not include the reference images so that the
            audios and hidden_states are strictly aligned
            """
            if len(hidden_states.shape) == 2:  # add a batch-size dim
                hidden_states = hidden_states.unsqueeze(0)  # bs = 1
            t, h, w = grid_sizes[0].tolist()
            n_tokens = t * h * w
            ori_dtype = hidden_states.dtype
            device = hidden_states.device
            bs, n_tokens_per_rank = hidden_states.shape[:2]
            if seq_p_group is not None:
                sp_size = dist.get_world_size(seq_p_group)
                sp_rank = dist.get_rank(seq_p_group)
            else:
                sp_size = 1
                sp_rank = 0
            tail_length = n_tokens_per_rank * sp_size - n_tokens
            n_unused_ranks = tail_length // n_tokens_per_rank
            if sp_rank > sp_size - n_unused_ranks - 1:
                n_query_tokens = 0
            elif sp_rank == sp_size - n_unused_ranks - 1:
                n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
            else:
                n_query_tokens = n_tokens_per_rank
            if n_query_tokens > 0:
                hidden_states_aligned = hidden_states[:, :n_query_tokens]
                hidden_states_tail = hidden_states[:, n_query_tokens:]
            else:
                # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
                hidden_states_aligned = hidden_states[:, :1]
                hidden_states_tail = hidden_states[:, 1:]
            q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
            q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
            """
            processing audio features in sp_state can be moved outside.
            """
            x = x[:, t0:t1]
            x = x.to(dtype)
            k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
            assert q_lens.shape == k_lens.shape
            # ca_block: the cross-attention module
            residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
            residual = residual.to(ori_dtype)
            # the audio cross-attention output is injected back as a residual
            if n_query_tokens == 0:
                residual = residual * 0.0
            hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
            if len(hidden_states.shape) == 3:  #
                hidden_states = hidden_states.squeeze(0)  # bs = 1
            return hidden_states

    @torch.no_grad()
    def forward_audio_proj(self, audio_feat, latent_frame):
        x = self.audio_proj(audio_feat, latent_frame)
        x = self.rearange_audio_features(x)
        x = x + self.audio_pe
        if self.time_embedding is not None:
            t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
        else:
            t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
        ret_dict = {}
        for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
            block_dict = {
                "kwargs": {
                    "ca_block": self.ca[block_idx],
                    "x": x,
                    "weight": weight,
                    "t_emb": t_emb,
                    "dtype": x.dtype,
                    "seq_p_group": seq_p_group,
                },
                "modify_func": modify_hidden_states,
            }
            ret_dict[base_idx] = block_dict
        return ret_dict

    @classmethod
    def from_transformer(
        cls,
        transformer,
        audio_feature_dim: int = 1024,
        interval: int = 1,
        time_freq_dim: int = 256,
        projection_transformer_layers: int = 4,
    ):
        num_attention_heads = transformer.config["num_heads"]
        base_num_layers = transformer.config["num_layers"]
        attention_head_dim = transformer.config["dim"] // num_attention_heads
        audio_adapter = AudioAdapter(
            attention_head_dim,
            num_attention_heads,
            base_num_layers,
            interval=interval,
            audio_feature_dim=audio_feature_dim,
            time_freq_dim=time_freq_dim,
            projection_transformer_layers=projection_transformer_layers,
            mlp_dims=(1024, 1024, 32 * audio_feature_dim),
        )
        return audio_adapter

    def get_fsdp_wrap_module_list(self,):
        ret_list = list(self.ca)
        return ret_list

    def enable_gradient_checkpointing(self,):
        pass


class AudioAdapterPipe:
    def __init__(
        self,
        audio_adapter: AudioAdapter,
        audio_encoder_repo: str = "microsoft/wavlm-base-plus",
        dtype=torch.float32,
        device="cuda",
        tgt_fps: int = 15,
        weight: float = 1.0,
        cpu_offload: bool = False,
        seq_p_group=None,
    ) -> None:
        self.seq_p_group = seq_p_group
        self.audio_adapter = audio_adapter
        self.dtype = dtype
        self.audio_encoder_dtype = torch.float16
        self.cpu_offload = cpu_offload
        ## audio encoder
        self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
        self.audio_encoder.eval()
        self.audio_encoder.to(device, self.audio_encoder_dtype)
        self.tgt_fps = tgt_fps
        self.weight = weight
        if "base" in audio_encoder_repo:
            self.audio_feature_dim = 768
        else:
            self.audio_feature_dim = 1024

    def update_model(self, audio_adapter):
        self.audio_adapter = audio_adapter

    def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
        # audio_input_feat is from AudioPreprocessor
        latent_frame = latent_shape[2]
        if len(audio_input_feat.shape) == 1:
            # expand to batch size = 1
            audio_input_feat = audio_input_feat.unsqueeze(0)
        latent_frame = latent_shape[1]
        video_frame = (latent_frame - 1) * 4 + 1
        audio_length = int(50 / self.tgt_fps * video_frame)
        with torch.no_grad():
            try:
                if self.cpu_offload:
                    self.audio_encoder = self.audio_encoder.to("cuda")
                audio_feat = self.audio_encoder(audio_input_feat.to(self.audio_encoder_dtype), return_dict=True).last_hidden_state
                if self.cpu_offload:
                    self.audio_encoder = self.audio_encoder.to("cpu")
            except Exception as err:
                audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to("cuda")
                print(err)
        audio_feat = audio_feat.to(self.dtype)
        if dropout_cond is not None:
            audio_feat = dropout_cond(audio_feat)
        return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight, seq_p_group=self.seq_p_group)
        return x
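A quick illustration (toy dimensions, not code from this commit) of what the new constructor flags change: with quantized=True and quant_scheme="fp8", to_q and to_out become SglQuantLinearFp8 while to_kv stays a plain nn.Linear.

    from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import PerceiverAttentionCA

    ca = PerceiverAttentionCA(dim_head=128, heads=16, kv_dim=2048, adaLN=True,
                              quantized=True, quant_scheme="fp8")
    print(type(ca.to_q).__name__)   # SglQuantLinearFp8
    print(type(ca.to_kv).__name__)  # Linear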
lightx2v/models/input_encoders/hf/seko_audio/audio_encoder.py (new file, 0 → 100644)

import torch
from transformers import AutoFeatureExtractor, AutoModel

from lightx2v.utils.envs import *


class SekoAudioEncoderModel:
    def __init__(self, model_path, audio_sr):
        self.model_path = model_path
        self.audio_sr = audio_sr
        self.load()

    def load(self):
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(self.model_path)
        self.audio_feature_encoder = AutoModel.from_pretrained(self.model_path)
        self.audio_feature_encoder.eval()
        self.audio_feature_encoder.to(GET_DTYPE())

    def to_cpu(self):
        self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")

    def to_cuda(self):
        self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")

    @torch.no_grad()
    def infer(self, audio_segment):
        audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.audio_feature_encoder.device).to(dtype=GET_DTYPE())
        audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
        return audio_feat
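Minimal usage sketch for the new encoder wrapper (the model path and the silent test signal below are placeholders, not values from the repo):

    import numpy as np
    from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel

    encoder = SekoAudioEncoderModel(model_path="/path/to/model/audio_encoder", audio_sr=16000)
    encoder.to_cuda()
    segment = np.zeros(16000, dtype=np.float32)  # one second of 16 kHz audio
    features = encoder.infer(segment)            # last_hidden_state, [1, T, hidden_dim]
    encoder.to_cpu()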
lightx2v/models/networks/wan/audio_model.py

...
@@ -26,6 +26,11 @@ class WanAudioModel(WanModel):
        self.post_infer_class = WanAudioPostInfer
        self.transformer_infer_class = WanAudioTransformerInfer

+    def set_audio_adapter(self, audio_adapter):
+        self.audio_adapter = audio_adapter
+        self.pre_infer.set_audio_adapter(self.audio_adapter)
+        self.transformer_infer.set_audio_adapter(self.audio_adapter)
+

class Wan22MoeAudioModel(WanAudioModel):
    def _load_ckpt(self, unified_dtype, sensitive_layer):
...
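Schematic of the fan-out that set_audio_adapter performs (stand-in classes, not the real model): one adapter instance is shared with the pre-infer and transformer-infer stages, mirroring the set_audio_adapter hooks added across this commit.

    class _Stage:
        def set_audio_adapter(self, audio_adapter):
            self.audio_adapter = audio_adapter

    class _Model:
        def __init__(self):
            self.pre_infer = _Stage()
            self.transformer_infer = _Stage()

        def set_audio_adapter(self, audio_adapter):
            self.audio_adapter = audio_adapter
            self.pre_infer.set_audio_adapter(self.audio_adapter)
            self.transformer_infer.set_audio_adapter(self.audio_adapter)

    m = _Model()
    m.set_audio_adapter(object())
    assert m.pre_infer.audio_adapter is m.transformer_infer.audio_adapter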
lightx2v/models/networks/wan/infer/audio/post_infer.py

import math

import torch

from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
...
@@ -8,32 +6,14 @@ from lightx2v.utils.envs import *
class WanAudioPostInfer(WanPostInfer):
    def __init__(self, config):
-        self.out_dim = config["out_dim"]
-        self.patch_size = (1, 2, 2)
-        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
-        self.infer_dtype = GET_DTYPE()
-        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
-
-    def set_scheduler(self, scheduler):
-        self.scheduler = scheduler
+        super().__init__(config)

    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def infer(self, x, pre_infer_out):
-        x = x[:, :pre_infer_out.valid_patch_length]
+        x = x[:pre_infer_out.seq_lens[0]]
        x = self.unpatchify(x, pre_infer_out.grid_sizes)
        if self.clean_cuda_cache:
            torch.cuda.empty_cache()
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        x = x.unsqueeze(0)
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out
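A standalone illustration of the unpatchify einsum above: tokens laid out on an (f, h, w) grid with patch_size (1, 2, 2) are folded back into a [C, F, H, W] latent. Shapes are toy values, not the model's real sizes.

    import math
    import torch

    f, h, w = 2, 3, 4      # patched grid sizes
    p, q, r = 1, 2, 2      # patch_size
    c = 8                  # out_dim

    tokens = torch.randn(f * h * w, p * q * r * c)              # [seq_len, patch_voxels * C]
    u = tokens[: math.prod((f, h, w))].view(f, h, w, p, q, r, c)
    u = torch.einsum("fhwpqrc->cfphqwr", u)
    latent = u.reshape(c, f * p, h * q, w * r)
    print(latent.shape)  # torch.Size([8, 2, 6, 8])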
lightx2v/models/networks/wan/infer/audio/pre_infer.py

...
@@ -35,6 +35,9 @@ class WanAudioPreInfer(WanPreInfer):
        else:
            self.sp_size = 1

+    def set_audio_adapter(self, audio_adapter):
+        self.audio_adapter = audio_adapter
+
    def infer(self, weights, inputs):
        prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
        if self.config.model_cls == "wan2.2_audio":
...
@@ -48,7 +51,7 @@ class WanAudioPreInfer(WanPreInfer):
            hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)

        hidden_states = hidden_states.squeeze(0)
-        x = [hidden_states]
+        x = hidden_states
        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
        if self.config.model_cls == "wan2.2_audio":
...
@@ -61,31 +64,23 @@ class WanAudioPreInfer(WanPreInfer):
            temp_ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * t])
            t = temp_ts.unsqueeze(0)

        audio_dit_blocks = []
        audio_encoder_output = inputs["audio_encoder_output"]
        audio_model_input = {
            "audio_input_feat": audio_encoder_output.to(hidden_states.device),
            "latent_shape": hidden_states.shape,
            "timestep": t,
        }
        audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
        # audio_dit_blocks = None  ## Debug Drop Audio
        t_emb = self.audio_adapter.time_embedding(t).unflatten(1, (3, -1))

        if self.scheduler.infer_condition:
            context = inputs["text_encoder_output"]["context"]
        else:
            context = inputs["text_encoder_output"]["context_null"]
-        seq_len = self.scheduler.seq_len
+        # seq_len = self.scheduler.seq_len
        clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
        ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)

-        batch_size = len(x)
-        num_channels, _, height, width = x[0].shape
+        # batch_size = len(x)
+        num_channels, _, height, width = x.shape
        _, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
        if ref_num_channels != num_channels:
            zero_padding = torch.zeros(
-                (batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
+                (1, num_channels - ref_num_channels, ref_num_frames, height, width),
                dtype=self.scheduler.latents.dtype,
                device=self.scheduler.latents.device,
            )
...
@@ -93,13 +88,10 @@ class WanAudioPreInfer(WanPreInfer):
        y = list(torch.unbind(ref_image_encoder, dim=0))  # the first (batch) dimension becomes a list

        # embeddings
        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
        x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
        assert seq_lens.max() <= seq_len
        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
        valid_patch_length = x[0].size(0)
        x = weights.patch_embedding.apply(x.unsqueeze(0))
        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
        x = x.flatten(2).transpose(1, 2).contiguous()
        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
        y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
        # y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
...
@@ -169,12 +161,11 @@ class WanAudioPreInfer(WanPreInfer):
        return WanPreInferModuleOutput(
            embed=embed,
-            grid_sizes=x_grid_sizes,
+            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context,
-            audio_dit_blocks=audio_dit_blocks,
-            valid_patch_length=valid_patch_length,
+            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "t_emb": t_emb},
        )
lightx2v/models/networks/wan/infer/audio/transformer_infer.py

import torch
import torch.distributed as dist

from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import get_q_lens_audio_range
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, compute_freqs_audio_dist
...
@@ -5,7 +9,13 @@ from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, comput
class WanAudioTransformerInfer(WanOffloadTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.num_tokens = 32
        self.num_tokens_x4 = self.num_tokens * 4

    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter

    @torch.no_grad()
    def compute_freqs(self, q, grid_sizes, freqs):
        if self.config["seq_parallel"]:
            freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
...
@@ -13,13 +23,77 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
        else:
            freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
        return freqs_i

    @torch.no_grad()
    def post_process(self, x, y, c_gate_msa, pre_infer_out):
        x = super().post_process(x, y, c_gate_msa, pre_infer_out)
        # Apply audio_dit if available
        if pre_infer_out.audio_dit_blocks is not None and hasattr(self, "block_idx"):
            for ipa_out in pre_infer_out.audio_dit_blocks:
                if self.block_idx in ipa_out:
                    cur_modify = ipa_out[self.block_idx]
                    x = cur_modify["modify_func"](x, pre_infer_out.grid_sizes, **cur_modify["kwargs"])
        x = self.modify_hidden_states(
            hidden_states=x,
            grid_sizes=pre_infer_out.grid_sizes,
            ca_block=self.audio_adapter.ca[self.block_idx],
            audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
            t_emb=pre_infer_out.adapter_output["t_emb"],
            weight=1.0,
            seq_p_group=self.seq_p_group,
        )
        return x

    @torch.no_grad()
    def modify_hidden_states(self, hidden_states, grid_sizes, ca_block, audio_encoder_output, t_emb, weight, seq_p_group):
        """thw specify the latent_frame, latent_height, latent_width after
        hidden_states is patchified.
        latent_frame does not include the reference images so that the
        audios and hidden_states are strictly aligned
        """
        if len(hidden_states.shape) == 2:  # add a batch-size dim
            hidden_states = hidden_states.unsqueeze(0)  # bs = 1
        t, h, w = grid_sizes[0].tolist()
        n_tokens = t * h * w
        ori_dtype = hidden_states.dtype
        device = hidden_states.device
        bs, n_tokens_per_rank = hidden_states.shape[:2]
        if seq_p_group is not None:
            sp_size = dist.get_world_size(seq_p_group)
            sp_rank = dist.get_rank(seq_p_group)
        else:
            sp_size = 1
            sp_rank = 0
        tail_length = n_tokens_per_rank * sp_size - n_tokens
        n_unused_ranks = tail_length // n_tokens_per_rank
        if sp_rank > sp_size - n_unused_ranks - 1:
            n_query_tokens = 0
        elif sp_rank == sp_size - n_unused_ranks - 1:
            n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
        else:
            n_query_tokens = n_tokens_per_rank
        if n_query_tokens > 0:
            hidden_states_aligned = hidden_states[:, :n_query_tokens]
            hidden_states_tail = hidden_states[:, n_query_tokens:]
        else:
            # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
            hidden_states_aligned = hidden_states[:, :1]
            hidden_states_tail = hidden_states[:, 1:]
        q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
        q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
        """
        processing audio features in sp_state can be moved outside.
        """
        audio_encoder_output = audio_encoder_output[:, t0:t1]
        k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
        assert q_lens.shape == k_lens.shape
        # ca_block: the cross-attention module
        residual = ca_block(audio_encoder_output, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
        residual = residual.to(ori_dtype)
        # the audio cross-attention output is injected back as a residual
        if n_query_tokens == 0:
            residual = residual * 0.0
        hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
        if len(hidden_states.shape) == 3:  #
            hidden_states = hidden_states.squeeze(0)  # bs = 1
        return hidden_states
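A standalone sketch of the rank/token bookkeeping above: given n_tokens padded up to sp_size * n_tokens_per_rank for sequence parallelism, compute how many real query tokens each rank owns (pure Python, toy numbers).

    def query_tokens_per_rank(n_tokens, n_tokens_per_rank, sp_size):
        tail_length = n_tokens_per_rank * sp_size - n_tokens
        n_unused_ranks = tail_length // n_tokens_per_rank
        out = []
        for sp_rank in range(sp_size):
            if sp_rank > sp_size - n_unused_ranks - 1:
                n_query_tokens = 0  # rank holds only padding
            elif sp_rank == sp_size - n_unused_ranks - 1:
                n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
            else:
                n_query_tokens = n_tokens_per_rank  # rank is fully used
            out.append(n_query_tokens)
        return out

    # 5 real tokens padded to 3 ranks * 3 tokens: the last rank is pure padding.
    print(query_tokens_per_rank(n_tokens=5, n_tokens_per_rank=3, sp_size=3))  # [3, 2, 0]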
lightx2v/models/networks/wan/infer/module_io.py

from dataclasses import dataclass
-from typing import Any, List, Optional
+from typing import Any, Dict

import torch


@dataclass
class WanPreInferModuleOutput:
+    # wan base model
    embed: torch.Tensor
    grid_sizes: torch.Tensor
    x: torch.Tensor
...
@@ -13,7 +14,6 @@ class WanPreInferModuleOutput:
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor
-    audio_dit_blocks: List[Any] = None
-    valid_patch_length: Optional[int] = None
-    hints: List[Any] = None
-    context_scale: float = 1.0
+    # wan adapter model
+    adapter_output: Dict[str, Any] = None
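The adapter-specific fields now travel in a single adapter_output dict instead of dedicated dataclass attributes. A schematic stand-in (not the real class, which also carries the tensor fields above) showing how call sites change:

    from dataclasses import dataclass
    from typing import Any, Dict

    @dataclass
    class PreInferOut:
        # base-model tensor fields (embed, grid_sizes, x, ...) would live here
        adapter_output: Dict[str, Any] = None

    out = PreInferOut(adapter_output={"hints": [0.1, 0.2], "context_scale": 1.0})

    # before: pre_infer_out.hints[idx] * pre_infer_out.context_scale
    # after:  everything adapter-specific goes through adapter_output
    scaled = out.adapter_output["hints"][0] * out.adapter_output.get("context_scale", 1.0)
    print(scaled)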
lightx2v/models/networks/wan/infer/vace/transformer_infer.py

...
@@ -9,7 +9,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
        self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}

    def infer(self, weights, pre_infer_out):
-        pre_infer_out.hints = self.infer_vace(weights, pre_infer_out)
+        pre_infer_out.adapter_output["hints"] = self.infer_vace(weights, pre_infer_out)
        x = self.infer_main_blocks(weights, pre_infer_out)
        return self.infer_non_blocks(weights, x, pre_infer_out.embed)
...
@@ -40,6 +40,6 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
        x = super().post_process(x, y, c_gate_msa, pre_infer_out)
        if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
            hint_idx = self.vace_blocks_mapping[self.block_idx]
-            x = x + pre_infer_out.hints[hint_idx] * pre_infer_out.context_scale
+            x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
        return x
lightx2v/models/runners/base_runner.py

-from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+from abc import ABC

from lightx2v.utils.utils import save_videos_grid


-class TransformerModel(Protocol):
-    """Protocol for transformer models"""
-
-    def set_scheduler(self, scheduler: Any) -> None: ...
-
-    def scheduler(self) -> Any: ...
-
-
-class TextEncoderModel(Protocol):
-    """Protocol for text encoder models"""
-
-    def infer(self, texts: List[str], config: Dict[str, Any]) -> Any: ...
-
-
-class ImageEncoderModel(Protocol):
-    """Protocol for image encoder models"""
-
-    def encode(self, image: Any) -> Any: ...
-
-
-class VAEModel(Protocol):
-    """Protocol for VAE models"""
-
-    def encode(self, image: Any) -> Tuple[Any, Dict[str, Any]]: ...
-
-    def decode(self, latents: Any, generator: Optional[Any] = None, config: Optional[Dict[str, Any]] = None) -> Any: ...
-
-
class BaseRunner(ABC):
    """Abstract base class for all Runners

    Defines interface methods that all subclasses must implement
    """

-    def __init__(self, config: Dict[str, Any]):
+    def __init__(self, config):
        self.config = config

-    @abstractmethod
-    def load_transformer(self) -> TransformerModel:
+    def load_transformer(self):
        """Load transformer model

        Returns:
...
@@ -48,8 +20,7 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def load_text_encoder(self) -> Union[TextEncoderModel, List[TextEncoderModel]]:
+    def load_text_encoder(self):
        """Load text encoder

        Returns:
...
@@ -57,8 +28,7 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def load_image_encoder(self) -> Optional[ImageEncoderModel]:
+    def load_image_encoder(self):
        """Load image encoder

        Returns:
...
@@ -66,8 +36,7 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def load_vae(self) -> Tuple[VAEModel, VAEModel]:
+    def load_vae(self):
        """Load VAE encoder and decoder

        Returns:
...
@@ -75,8 +44,7 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def run_image_encoder(self, img: Any) -> Any:
+    def run_image_encoder(self, img):
        """Run image encoder

        Args:
...
@@ -87,8 +55,7 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def run_vae_encoder(self, img: Any) -> Tuple[Any, Dict[str, Any]]:
+    def run_vae_encoder(self, img):
        """Run VAE encoder

        Args:
...
@@ -99,8 +66,7 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def run_text_encoder(self, prompt: str, img: Optional[Any] = None) -> Any:
+    def run_text_encoder(self, prompt, img):
        """Run text encoder

        Args:
...
@@ -112,8 +78,7 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def get_encoder_output_i2v(self, clip_encoder_out: Any, vae_encoder_out: Any, text_encoder_output: Any, img: Any) -> Dict[str, Any]:
+    def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
        """Combine encoder outputs for i2v task

        Args:
...
@@ -127,12 +92,11 @@ class BaseRunner(ABC):
        """
        pass

-    @abstractmethod
-    def init_scheduler(self) -> None:
+    def init_scheduler(self):
        """Initialize scheduler"""
        pass

-    def set_target_shape(self) -> Dict[str, Any]:
+    def set_target_shape(self):
        """Set target shape

        Subclasses can override this method to provide specific implementation
...
@@ -142,7 +106,7 @@ class BaseRunner(ABC):
        """
        return {}

-    def save_video_func(self, images: Any) -> None:
+    def save_video_func(self, images):
        """Save video implementation

        Subclasses can override this method to customize save logic
...
@@ -152,7 +116,7 @@ class BaseRunner(ABC):
        """
        save_videos_grid(images, self.config.get("save_video_path", "./output.mp4"), n_rows=1, fps=self.config.get("fps", 8))

-    def load_vae_decoder(self) -> VAEModel:
+    def load_vae_decoder(self):
        """Load VAE decoder

        Default implementation: get decoder from load_vae method
...
@@ -164,3 +128,21 @@ class BaseRunner(ABC):
        if not hasattr(self, "vae_decoder") or self.vae_decoder is None:
            _, self.vae_decoder = self.load_vae()
        return self.vae_decoder

+    def get_video_segment_num(self):
+        self.video_segment_num = 1
+
+    def init_run(self):
+        pass
+
+    def init_run_segment(self, segment_idx):
+        self.segment_idx = segment_idx
+
+    def run_segment(self, total_steps=None):
+        pass
+
+    def end_run_segment(self):
+        pass
+
+    def end_run(self):
+        pass
lightx2v/models/runners/default_runner.py

...
@@ -3,6 +3,7 @@ import gc
import requests
import torch
import torch.distributed as dist
+import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
...
@@ -35,8 +36,6 @@ class DefaultRunner(BaseRunner):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
-        self.run_dit = self._run_dit_local
-        self.run_vae_decoder = self._run_vae_decoder_local
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
...
@@ -108,7 +107,7 @@ class DefaultRunner(BaseRunner):
    def set_progress_callback(self, callback):
        self.progress_callback = callback

-    def run(self, total_steps=None):
+    def run_segment(self, total_steps=None):
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
...
@@ -130,8 +129,7 @@ class DefaultRunner(BaseRunner):
    def run_step(self):
        self.inputs = self.run_input_encoder()
-        self.set_target_shape()
-        self.run_dit(total_steps=1)
+        self.run_main(total_steps=1)

    def end_run(self):
        self.model.scheduler.clear()
...
@@ -147,10 +145,15 @@ class DefaultRunner(BaseRunner):
        torch.cuda.empty_cache()
        gc.collect()

+    def read_image_input(self, img_path):
+        img = Image.open(img_path).convert("RGB")
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        return img
+
    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img = Image.open(self.config["image_path"]).convert("RGB")
+        img = self.read_image_input(self.config["image_path"])
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
...
@@ -172,8 +175,8 @@ class DefaultRunner(BaseRunner):
    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        first_frame = Image.open(self.config["image_path"]).convert("RGB")
-        last_frame = Image.open(self.config["last_frame_path"]).convert("RGB")
+        first_frame = self.read_image_input(self.config["image_path"])
+        last_frame = self.read_image_input(self.config["last_frame_path"])
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
...
@@ -201,20 +204,32 @@ class DefaultRunner(BaseRunner):
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

-    @ProfilingContext("Run DiT")
-    def _run_dit_local(self, total_steps=None):
+    def init_run(self):
+        self.set_target_shape()
+        self.get_video_segment_num()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None
-        latents, generator = self.run(total_steps)

+    @ProfilingContext("Run DiT")
+    def run_main(self, total_steps=None):
+        self.init_run()
+        for segment_idx in range(self.video_segment_num):
+            # 1. default do nothing
+            self.init_run_segment(segment_idx)
+            # 2. main inference loop
+            latents, generator = self.run_segment(total_steps=total_steps)
+            # 3. vae decoder
+            self.gen_video = self.run_vae_decoder(latents, generator)
+            # 4. default do nothing
+            self.end_run_segment()
+        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
-    def _run_vae_decoder_local(self, latents, generator):
+    def run_vae_decoder(self, latents, generator):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
...
@@ -240,15 +255,15 @@ class DefaultRunner(BaseRunner):
            logger.info(f"Enhanced prompt: {enhanced_prompt}")
        return enhanced_prompt

-    def process_images_after_vae_decoder(self, images, save_video=True):
-        images = vae_to_comfyui_image(images)
+    def process_images_after_vae_decoder(self, save_video=True):
+        self.gen_video = vae_to_comfyui_image(self.gen_video)
        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
-            images = self.vfi_model.interpolate_frames(
-                images,
+            self.gen_video = self.vfi_model.interpolate_frames(
+                self.gen_video,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )
...
@@ -262,24 +277,21 @@ class DefaultRunner(BaseRunner):
        if not dist.is_initialized() or dist.get_rank() == 0:
            logger.info(f"🎬 Start to save video 🎬")
-            save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")
+            save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
            logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")

+        return {"video": self.gen_video}

    def run_pipeline(self, save_video=True):
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
        self.inputs = self.run_input_encoder()
-        self.set_target_shape()
-        latents, generator = self.run_dit()
+        self.run_main()
-        images = self.run_vae_decoder(latents, generator)
-        self.process_images_after_vae_decoder(images, save_video=save_video)
+        gen_video = self.process_images_after_vae_decoder(save_video=save_video)
-        del latents, generator
        torch.cuda.empty_cache()
        gc.collect()
-        # Return (images, audio) - audio is None for default runner
-        return images, None
+        return gen_video
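Schematic of the new segment-based flow (stand-in classes, not the real runners): run_main() drives the hooks that BaseRunner now defines, and single-segment runners simply keep video_segment_num = 1.

    class MiniRunner:
        def get_video_segment_num(self):
            self.video_segment_num = 3  # e.g. one segment per audio chunk

        def init_run(self):
            self.get_video_segment_num()
            self.outputs = []

        def init_run_segment(self, segment_idx):
            self.segment_idx = segment_idx

        def run_segment(self, total_steps=None):
            return f"latents[{self.segment_idx}]", "generator"

        def run_vae_decoder(self, latents, generator):
            return f"video from {latents}"

        def end_run_segment(self):
            self.outputs.append(self.gen_video)

        def end_run(self):
            pass

        def run_main(self, total_steps=None):
            self.init_run()
            for segment_idx in range(self.video_segment_num):
                self.init_run_segment(segment_idx)
                latents, generator = self.run_segment(total_steps=total_steps)
                self.gen_video = self.run_vae_decoder(latents, generator)
                self.end_run_segment()
            self.end_run()

    r = MiniRunner()
    r.run_main()
    print(r.outputs)  # three decoded segments, in order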
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
d8454a2b
import
gc
import
os
import
subprocess
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -9,6 +8,7 @@ import numpy as np
import
torch
import
torch.distributed
as
dist
import
torchaudio
as
ta
import
torchvision.transforms.functional
as
TF
from
PIL
import
Image
from
einops
import
rearrange
from
loguru
import
logger
...
...
@@ -16,29 +16,19 @@ from torchvision.transforms import InterpolationMode
from
torchvision.transforms.functional
import
resize
from
transformers
import
AutoFeatureExtractor
from
lightx2v.models.networks.wan.audio_adapter
import
AudioAdapter
,
AudioAdapterPipe
,
rank0_load_state_dict_from_path
from
lightx2v.models.input_encoders.hf.seko_audio.audio_adapter
import
AudioAdapter
,
rank0_load_state_dict_from_path
from
lightx2v.models.input_encoders.hf.seko_audio.audio_encoder
import
SekoAudioEncoderModel
from
lightx2v.models.networks.wan.audio_model
import
Wan22MoeAudioModel
,
WanAudioModel
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.runners.wan.wan_runner
import
MultiModelStruct
,
WanRunner
from
lightx2v.models.schedulers.wan.audio.scheduler
import
ConsistencyModelScheduler
from
lightx2v.models.video_encoders.hf.wan.vae_2_2
import
Wan2_2_VAE
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext
,
ProfilingContext4Debug
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.utils
import
find_torch_model_path
,
save_to_video
,
vae_to_comfyui_image
@
contextmanager
def
memory_efficient_inference
():
"""Context manager for memory-efficient inference"""
try
:
yield
finally
:
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
def
get_optimal_patched_size_with_sp
(
patched_h
,
patched_w
,
sp_size
):
assert
sp_size
>
0
and
(
sp_size
&
(
sp_size
-
1
))
==
0
,
"sp_size must be a power of 2"
...
...
@@ -244,17 +234,91 @@ class AudioProcessor:
return
segments
class
VideoGenerator
:
"""Handles video generation for each segment"""
def
__init__
(
self
,
model
,
vae_encoder
,
vae_decoder
,
config
,
progress_callback
=
None
):
self
.
model
=
model
self
.
vae_encoder
=
vae_encoder
self
.
vae_decoder
=
vae_decoder
self
.
config
=
config
@
RUNNER_REGISTER
(
"wan2.1_audio"
)
class
WanAudioRunner
(
WanRunner
):
# type:ignore
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
_audio_processor
=
None
self
.
_video_generator
=
None
self
.
_audio_preprocess
=
None
self
.
frame_preprocessor
=
FramePreprocessor
()
self
.
progress_callback
=
progress_callback
self
.
total_segments
=
1
def
init_scheduler
(
self
):
"""Initialize consistency model scheduler"""
scheduler
=
ConsistencyModelScheduler
(
self
.
config
)
self
.
model
.
set_scheduler
(
scheduler
)
def
read_audio_input
(
self
):
"""Read audio input"""
audio_sr
=
self
.
config
.
get
(
"audio_sr"
,
16000
)
target_fps
=
self
.
config
.
get
(
"target_fps"
,
16
)
self
.
_audio_processor
=
AudioProcessor
(
audio_sr
,
target_fps
)
audio_array
=
self
.
_audio_processor
.
load_audio
(
self
.
config
[
"audio_path"
])
video_duration
=
self
.
config
.
get
(
"video_duration"
,
5
)
audio_len
=
int
(
audio_array
.
shape
[
0
]
/
audio_sr
*
target_fps
)
expected_frames
=
min
(
max
(
1
,
int
(
video_duration
*
target_fps
)),
audio_len
)
# Segment audio
audio_segments
=
self
.
_audio_processor
.
segment_audio
(
audio_array
,
expected_frames
,
self
.
config
.
get
(
"target_video_length"
,
81
))
return
audio_segments
,
expected_frames
def
read_image_input
(
self
,
img_path
):
ref_img
=
Image
.
open
(
img_path
).
convert
(
"RGB"
)
ref_img
=
TF
.
to_tensor
(
ref_img
).
sub_
(
0.5
).
div_
(
0.5
).
unsqueeze
(
0
).
cuda
()
ref_img
,
h
,
w
=
adaptive_resize
(
ref_img
)
patched_h
=
h
//
self
.
config
.
vae_stride
[
1
]
//
self
.
config
.
patch_size
[
1
]
patched_w
=
w
//
self
.
config
.
vae_stride
[
2
]
//
self
.
config
.
patch_size
[
2
]
patched_h
,
patched_w
=
get_optimal_patched_size_with_sp
(
patched_h
,
patched_w
,
1
)
self
.
config
.
lat_h
=
patched_h
*
self
.
config
.
patch_size
[
1
]
self
.
config
.
lat_w
=
patched_w
*
self
.
config
.
patch_size
[
2
]
self
.
config
.
tgt_h
=
self
.
config
.
lat_h
*
self
.
config
.
vae_stride
[
1
]
self
.
config
.
tgt_w
=
self
.
config
.
lat_w
*
self
.
config
.
vae_stride
[
2
]
logger
.
info
(
f
"[wan_audio] tgt_h:
{
self
.
config
.
tgt_h
}
, tgt_w:
{
self
.
config
.
tgt_w
}
, lat_h:
{
self
.
config
.
lat_h
}
, lat_w:
{
self
.
config
.
lat_w
}
"
)
ref_img
=
torch
.
nn
.
functional
.
interpolate
(
ref_img
,
size
=
(
self
.
config
.
tgt_h
,
self
.
config
.
tgt_w
),
mode
=
"bicubic"
)
return
ref_img
def
run_image_encoder
(
self
,
first_frame
,
last_frame
=
None
):
clip_encoder_out
=
self
.
image_encoder
.
visual
([
first_frame
]).
squeeze
(
0
).
to
(
GET_DTYPE
())
if
self
.
config
.
get
(
"use_image_encoder"
,
True
)
else
None
return
clip_encoder_out
def
run_vae_encoder
(
self
,
img
):
img
=
rearrange
(
img
,
"1 C H W -> 1 C 1 H W"
)
vae_encoder_out
=
self
.
vae_encoder
.
encode
(
img
.
to
(
torch
.
float
))
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
vae_encoder_out
=
vae_encoder_out
.
unsqueeze
(
0
).
to
(
GET_DTYPE
())
else
:
if
isinstance
(
vae_encoder_out
,
list
):
vae_encoder_out
=
torch
.
stack
(
vae_encoder_out
,
dim
=
0
).
to
(
GET_DTYPE
())
return
vae_encoder_out
@
ProfilingContext
(
"Run Encoders"
)
def
_run_input_encoder_local_r2v_audio
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
img
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
clip_encoder_out
=
self
.
run_image_encoder
(
img
)
if
self
.
config
.
get
(
"use_image_encoder"
,
True
)
else
None
vae_encode_out
=
self
.
run_vae_encoder
(
img
)
audio_segments
,
expected_frames
=
self
.
read_audio_input
()
text_encoder_output
=
self
.
run_text_encoder
(
prompt
,
None
)
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
return
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
{
"clip_encoder_out"
:
clip_encoder_out
,
"vae_encoder_out"
:
vae_encode_out
,
},
"audio_segments"
:
audio_segments
,
"expected_frames"
:
expected_frames
,
}
def
prepare_prev_latents
(
self
,
prev_video
:
Optional
[
torch
.
Tensor
],
prev_frame_length
:
int
)
->
Optional
[
Dict
[
str
,
torch
.
Tensor
]]:
"""Prepare previous latents for conditioning"""
...
...
@@ -295,31 +359,6 @@ class VideoGenerator:
return
{
"prev_latents"
:
prev_latents
,
"prev_mask"
:
prev_mask
}
def
_wan22_masks_like
(
self
,
tensor
,
zero
=
False
,
generator
=
None
,
p
=
0.2
,
prev_length
=
1
):
assert
isinstance
(
tensor
,
list
)
out1
=
[
torch
.
ones
(
u
.
shape
,
dtype
=
u
.
dtype
,
device
=
u
.
device
)
for
u
in
tensor
]
out2
=
[
torch
.
ones
(
u
.
shape
,
dtype
=
u
.
dtype
,
device
=
u
.
device
)
for
u
in
tensor
]
if
prev_length
==
0
:
return
out1
,
out2
if
zero
:
if
generator
is
not
None
:
for
u
,
v
in
zip
(
out1
,
out2
):
random_num
=
torch
.
rand
(
1
,
generator
=
generator
,
device
=
generator
.
device
).
item
()
if
random_num
<
p
:
u
[:,
:
prev_length
]
=
torch
.
normal
(
mean
=-
3.5
,
std
=
0.5
,
size
=
(
1
,),
device
=
u
.
device
,
generator
=
generator
).
expand_as
(
u
[:,
:
prev_length
]).
exp
()
v
[:,
:
prev_length
]
=
torch
.
zeros_like
(
v
[:,
:
prev_length
])
else
:
u
[:,
:
prev_length
]
=
u
[:,
:
prev_length
]
v
[:,
:
prev_length
]
=
v
[:,
:
prev_length
]
else
:
for
u
,
v
in
zip
(
out1
,
out2
):
u
[:,
:
prev_length
]
=
torch
.
zeros_like
(
u
[:,
:
prev_length
])
v
[:,
:
prev_length
]
=
torch
.
zeros_like
(
v
[:,
:
prev_length
])
return
out1
,
out2
def
_wan_mask_rearrange
(
self
,
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Rearrange mask for WAN model"""
if
mask
.
ndim
==
3
:
...
...
@@ -332,250 +371,99 @@ class VideoGenerator:
mask
=
mask
.
view
(
mask
.
shape
[
1
]
//
4
,
4
,
h
,
w
)
return
mask
.
transpose
(
0
,
1
)
@
torch
.
no_grad
()
def
generate_segment
(
self
,
inputs
,
audio_features
,
prev_video
=
None
,
prev_frame_length
=
5
,
segment_idx
=
0
,
total_steps
=
None
):
"""Generate video segment"""
# Update inputs with audio features
inputs
[
"audio_encoder_output"
]
=
audio_features
# Reset scheduler for non-first segments
if
segment_idx
>
0
:
self
.
model
.
scheduler
.
reset
()
inputs
[
"previmg_encoder_output"
]
=
self
.
prepare_prev_latents
(
prev_video
,
prev_frame_length
)
def
get_video_segment_num
(
self
):
self
.
video_segment_num
=
len
(
self
.
inputs
[
"audio_segments"
])
# Run inference loop
if
total_steps
is
None
:
total_steps
=
self
.
model
.
scheduler
.
infer_steps
for
step_index
in
range
(
total_steps
):
logger
.
info
(
f
"==> Segment
{
segment_idx
}
, Step
{
step_index
}
/
{
total_steps
}
"
)
def
init_run
(
self
):
super
().
init_run
()
with
ProfilingContext4Debug
(
"step_pre"
):
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
self
.
gen_video_list
=
[]
self
.
cut_audio_list
=
[]
self
.
prev_video
=
None
with
ProfilingContext4Debug
(
"🚀 infer_main"
):
self
.
model
.
infer
(
inputs
)
def
init_run_segment
(
self
,
segment_idx
):
self
.
segment_idx
=
segment_idx
with
ProfilingContext4Debug
(
"step_post"
):
self
.
model
.
scheduler
.
step_post
()
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
prev_mask
=
inputs
[
"previmg_encoder_output"
][
"prev_mask"
]
prev_latents
=
inputs
[
"previmg_encoder_output"
][
"prev_latents"
]
self
.
model
.
scheduler
.
latents
=
(
1.0
-
prev_mask
[
0
])
*
prev_latents
+
prev_mask
[
0
]
*
self
.
model
.
scheduler
.
latents
self
.
segment
=
self
.
inputs
[
"audio_segments"
][
segment_idx
]
if
self
.
progress_callback
:
segment_progress
=
(
segment_idx
*
total_steps
+
step_index
+
1
)
/
(
self
.
total_segments
*
total_steps
)
self
.
progress_callback
(
int
(
segment_progress
*
100
),
100
)
# Decode latents
latents
=
self
.
model
.
scheduler
.
latents
generator
=
self
.
model
.
scheduler
.
generator
with
ProfilingContext
(
"Run VAE Decoder"
):
gen_video
=
self
.
vae_decoder
.
decode
(
latents
,
generator
=
generator
,
config
=
self
.
config
)
gen_video
=
torch
.
clamp
(
gen_video
,
-
1
,
1
).
to
(
torch
.
float
)
return
gen_video
@
RUNNER_REGISTER
(
"wan2.1_audio"
)
class
WanAudioRunner
(
WanRunner
):
# type:ignore
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
_audio_adapter_pipe
=
None
self
.
_audio_processor
=
None
self
.
_video_generator
=
None
self
.
_audio_preprocess
=
None
def
initialize
(
self
):
"""Initialize all models once for multiple runs"""
self
.
config
.
seed
=
self
.
config
.
seed
+
segment_idx
torch
.
manual_seed
(
self
.
config
.
seed
)
logger
.
info
(
f
"Processing segment
{
segment_idx
+
1
}
/
{
self
.
video_segment_num
}
, seed:
{
self
.
config
.
seed
}
"
)
# Initialize audio processor
audio_sr
=
self
.
config
.
get
(
"audio_sr"
,
16000
)
target_fps
=
self
.
config
.
get
(
"target_fps"
,
16
)
self
.
_audio_processor
=
AudioProcessor
(
audio_sr
,
target_fps
)
audio_features
=
self
.
audio_encoder
.
infer
(
self
.
segment
.
audio_array
).
to
(
self
.
model
.
device
)
audio_features
=
self
.
audio_adapter
.
forward_audio_proj
(
audio_features
,
self
.
model
.
scheduler
.
latents
.
shape
[
1
])
# Initialize scheduler
self
.
init_scheduler
()
self
.
inputs
[
"audio_encoder_output"
]
=
audio_features
def
init_scheduler
(
self
):
"""Initialize consistency model scheduler"""
scheduler
=
ConsistencyModelScheduler
(
self
.
config
)
self
.
model
.
set_scheduler
(
scheduler
)
# Reset scheduler for non-first segments
if
segment_idx
>
0
:
self
.
model
.
scheduler
.
reset
()
def
load_audio_adapter_lazy
(
self
):
"""Lazy load audio adapter when needed"""
if
self
.
_audio_adapter_pipe
is
not
None
:
return
self
.
_audio_adapter_pipe
self
.
inputs
[
"previmg_encoder_output"
]
=
self
.
prepare_prev_latents
(
self
.
prev_video
,
prev_frame_length
=
5
)
# Audio adapter
audio_adapter_path
=
self
.
config
[
"model_path"
]
+
"/audio_adapter.safetensors"
audio_adapter
=
AudioAdapter
.
from_transformer
(
self
.
model
,
audio_feature_dim
=
1024
,
interval
=
1
,
time_freq_dim
=
256
,
projection_transformer_layers
=
4
,
)
def
end_run_segment
(
self
):
self
.
gen_video
=
torch
.
clamp
(
self
.
gen_video
,
-
1
,
1
).
to
(
torch
.
float
)
# Audio encoder
cpu_offload
=
self
.
config
.
get
(
"cpu_offload"
,
False
)
if
cpu_offload
:
device
=
torch
.
device
(
"cpu"
)
else
:
device
=
torch
.
device
(
"cuda"
)
audio_encoder_repo
=
self
.
config
[
"model_path"
]
+
"/audio_encoder"
# Extract relevant frames
start_frame
=
0
if
self
.
segment_idx
==
0
else
5
start_audio_frame
=
0
if
self
.
segment_idx
==
0
else
int
(
6
*
self
.
_audio_processor
.
audio_sr
/
self
.
config
.
get
(
"target_fps"
,
16
))
if
self
.
model
.
transformer_infer
.
seq_p_group
is
not
None
:
seq_p_group
=
self
.
model
.
transformer_infer
.
seq_p_group
if
self
.
segment
.
is_last
and
self
.
segment
.
useful_length
:
end_frame
=
self
.
segment
.
end_frame
-
self
.
segment
.
start_frame
self
.
gen_video_list
.
append
(
self
.
gen_video
[:,
:,
start_frame
:
end_frame
].
cpu
())
self
.
cut_audio_list
.
append
(
self
.
segment
.
audio_array
[
start_audio_frame
:
self
.
segment
.
useful_length
])
elif
self
.
segment
.
useful_length
and
self
.
inputs
[
"expected_frames"
]
<
self
.
config
.
get
(
"target_video_length"
,
81
):
self
.
gen_video_list
.
append
(
self
.
gen_video
[:,
:,
start_frame
:
self
.
inputs
[
"expected_frames"
]].
cpu
())
self
.
cut_audio_list
.
append
(
self
.
segment
.
audio_array
[
start_audio_frame
:
self
.
segment
.
useful_length
])
else
:
seq_p_group
=
None
audio_adapter
=
rank0_load_state_dict_from_path
(
audio_adapter
,
audio_adapter_path
,
strict
=
False
)
self
.
_audio_adapter_pipe
=
AudioAdapterPipe
(
audio_adapter
,
audio_encoder_repo
=
audio_encoder_repo
,
dtype
=
GET_DTYPE
(),
device
=
device
,
weight
=
1.0
,
cpu_offload
=
cpu_offload
,
seq_p_group
=
seq_p_group
)
return
self
.
_audio_adapter_pipe
def
prepare_inputs
(
self
):
"""Prepare inputs for the model"""
image_encoder_output
=
None
if
os
.
path
.
isfile
(
self
.
config
.
image_path
):
with
ProfilingContext
(
"Run Img Encoder"
):
vae_encoder_out
,
clip_encoder_out
=
self
.
run_image_encoder
(
self
.
config
,
self
.
vae_encoder
)
image_encoder_output
=
{
"clip_encoder_out"
:
clip_encoder_out
,
"vae_encoder_out"
:
vae_encoder_out
,
}
self
.
gen_video_list
.
append
(
self
.
gen_video
[:,
:,
start_frame
:].
cpu
())
self
.
cut_audio_list
.
append
(
self
.
segment
.
audio_array
[
start_audio_frame
:])
# Update prev_video for next iteration
self
.
prev_video
=
self
.
gen_video
# Clean up GPU memory after each segment
del
self
.
gen_video
torch
.
cuda
.
empty_cache
()
def
process_images_after_vae_decoder
(
self
,
save_video
=
True
):
# Merge results
gen_lvideo
=
torch
.
cat
(
self
.
gen_video_list
,
dim
=
2
).
float
()
merge_audio
=
np
.
concatenate
(
self
.
cut_audio_list
,
axis
=
0
).
astype
(
np
.
float32
)
comfyui_images
=
vae_to_comfyui_image
(
gen_lvideo
)
# Apply frame interpolation if configured
if
"video_frame_interpolation"
in
self
.
config
and
self
.
vfi_model
is
not
None
:
target_fps
=
self
.
config
[
"video_frame_interpolation"
][
"target_fps"
]
logger
.
info
(
f
"Interpolating frames from
{
self
.
config
.
get
(
'fps'
,
16
)
}
to
{
target_fps
}
"
)
comfyui_images
=
self
.
vfi_model
.
interpolate_frames
(
comfyui_images
,
source_fps
=
self
.
config
.
get
(
"fps"
,
16
),
target_fps
=
target_fps
,
)
with
ProfilingContext
(
"Run Text Encoder"
):
img
=
Image
.
open
(
self
.
config
[
"image_path"
]).
convert
(
"RGB"
)
text_encoder_output
=
self
.
run_text_encoder
(
self
.
config
[
"prompt"
],
img
)
if
save_video
:
if
"video_frame_interpolation"
in
self
.
config
and
self
.
config
[
"video_frame_interpolation"
].
get
(
"target_fps"
):
fps
=
self
.
config
[
"video_frame_interpolation"
][
"target_fps"
]
else
:
fps
=
self
.
config
.
get
(
"fps"
,
16
)
self
.
set_target_shape
()
if
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
logger
.
info
(
f
"🎬 Start to save video 🎬"
)
return
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
image_encoder_output
,
"audio_adapter_pipe"
:
self
.
load_audio_adapter_lazy
()}
self
.
_save_video_with_audio
(
comfyui_images
,
merge_audio
,
fps
)
logger
.
info
(
f
"✅ Video saved successfully to:
{
self
.
config
.
save_video_path
}
✅"
)
def
run_pipeline
(
self
,
save_video
=
True
):
"""Optimized pipeline with modular components"""
# Convert audio to ComfyUI format
audio_waveform
=
torch
.
from_numpy
(
merge_audio
).
unsqueeze
(
0
).
unsqueeze
(
0
)
comfyui_audio
=
{
"waveform"
:
audio_waveform
,
"sample_rate"
:
self
.
_audio_processor
.
audio_sr
}
try
:
self
.
initialize
()
assert
self
.
_audio_processor
is
not
None
assert
self
.
_audio_preprocess
is
not
None
self
.
_video_generator
=
VideoGenerator
(
self
.
model
,
self
.
vae_encoder
,
self
.
vae_decoder
,
self
.
config
,
self
.
progress_callback
)
            with memory_efficient_inference():
                if self.config["use_prompt_enhancer"]:
                    self.config["prompt_enhanced"] = self.post_prompt_enhancer()
                self.inputs = self.prepare_inputs()
                # Re-initialize scheduler after image encoding sets correct dimensions
                self.init_scheduler()
                self.model.scheduler.prepare(self.inputs["image_encoder_output"])
                # Re-create video generator with updated model/scheduler
                self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
            # Process audio
            audio_array = self._audio_processor.load_audio(self.config["audio_path"])
            video_duration = self.config.get("video_duration", 5)
            target_fps = self.config.get("target_fps", 16)
            max_num_frames = self.config.get("target_video_length", 81)
            audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
            expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
            # Segment audio
            audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
            self._video_generator.total_segments = len(audio_segments)
            # Generate video segments
            gen_video_list = []
            cut_audio_list = []
            prev_video = None
            for idx, segment in enumerate(audio_segments):
                self.config.seed = self.config.seed + idx
                torch.manual_seed(self.config.seed)
                logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")
                # Process audio features
                audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
                # Generate video segment
                with memory_efficient_inference():
                    gen_video = self._video_generator.generate_segment(
                        self.inputs.copy(),  # Copy to avoid modifying original
                        audio_features,
                        prev_video=prev_video,
                        prev_frame_length=5,
                        segment_idx=idx,
                    )
                # Extract relevant frames
                start_frame = 0 if idx == 0 else 5
                start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)
                if segment.is_last and segment.useful_length:
                    end_frame = segment.end_frame - segment.start_frame
                    gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
                elif segment.useful_length and expected_frames < max_num_frames:
                    gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame:segment.useful_length])
                else:
                    gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame:])
                # Update prev_video for next iteration
                prev_video = gen_video
                # Clean up GPU memory after each segment
                del gen_video
                torch.cuda.empty_cache()
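
Note: a minimal standalone sketch of the frame-budget arithmetic used in the segment loop above; the sample rate, clip length, and fps values below are hypothetical, not taken from this diff.

# Hypothetical values for illustration only.
audio_sr = 16000                   # audio samples per second (assumed)
target_fps = 16                    # video frames per second
video_duration = 5                 # requested duration in seconds
max_num_frames = 81                # frames per generation segment
num_audio_samples = 7 * audio_sr   # a 7-second clip

# Frames the audio can cover, and the frames we actually want.
audio_len = int(num_audio_samples / audio_sr * target_fps)                   # 112
expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)   # 80

# With 80 expected frames and 81 frames per segment a single segment suffices;
# longer audio is split into segments that overlap by 5 frames and stitched back together.
print(audio_len, expected_frames)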
            # Merge results
            with memory_efficient_inference():
                gen_lvideo = torch.cat(gen_video_list, dim=2).float()
                merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
                comfyui_images = vae_to_comfyui_image(gen_lvideo)
            # Apply frame interpolation if configured
            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
                interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
                logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
                comfyui_images = self.vfi_model.interpolate_frames(
                    comfyui_images,
                    source_fps=target_fps,
                    target_fps=interpolation_target_fps,
                )
                target_fps = interpolation_target_fps
            # Convert audio to ComfyUI format
            audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
            comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
            # Save video if requested
            if (self.config.get("device_mesh") is not None and dist.get_rank() == 0) or self.config.get("device_mesh") is None:
                if save_video and self.config.get("save_video_path", None):
                    self._save_video_with_audio(comfyui_images, merge_audio, target_fps)
            # Final cleanup
            self.end_run()
            return comfyui_images, comfyui_audio
            return {"video": comfyui_images, "audio": comfyui_audio}
        finally:
            self._video_generator = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    def init_modules(self):
        super().init_modules()
        self.run_input_encoder = self._run_input_encoder_local_r2v_audio

    def _save_video_with_audio(self, images, audio_array, fps):
        """Save video with audio"""
...
...
@@ -620,63 +508,43 @@ class WanAudioRunner(WanRunner): # type:ignore
            lora_wrapper.apply_lora(lora_name, strength)
            logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        # XXX: trick
        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
        return base_model
    def run_image_encoder(self, config, vae_model):
        """Run image encoder"""
        ref_img = Image.open(config.image_path)
        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
        ref_img = torch.from_numpy(ref_img).cuda()
        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
        ref_img = ref_img[:, :3]
        adaptive = config.get("adaptive_resize", False)
        if adaptive:
            # Use adaptive_resize to modify aspect ratio
            ref_img, h, w = adaptive_resize(ref_img)
            patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
            patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
        else:
            h, w = ref_img.shape[2:]
            aspect_ratio = h / w
            max_area = config.target_height * config.target_width
            patched_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1])
            patched_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2])
        patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)
        config.lat_h = patched_h * self.config.patch_size[1]
        config.lat_w = patched_w * self.config.patch_size[2]

    def load_audio_encoder(self):
        model = SekoAudioEncoderModel(os.path.join(self.config["model_path"], "audio_encoder"), self.config["audio_sr"])
        return model
        config.tgt_h = config.lat_h * self.config.vae_stride[1]
        config.tgt_w = config.lat_w * self.config.vae_stride[2]
        logger.info(f"[wan_audio] adaptive_resize: {adaptive}, tgt_h: {config.tgt_h}, tgt_w: {config.tgt_w}, lat_h: {config.lat_h}, lat_w: {config.lat_w}")
        cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
        # clip encoder
        clip_encoder_out = self.image_encoder.visual([cond_frms]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
        # vae encode
        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float))
        if self.config.model_cls == "wan2.2_audio":
            vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())

    def load_audio_adapter(self):
        audio_adapter = AudioAdapter(
            attention_head_dim=5120 // self.config["num_heads"],
            num_attention_heads=self.config["num_heads"],
            base_num_layers=self.config["num_layers"],
            interval=1,
            audio_feature_dim=1024,
            time_freq_dim=256,
            projection_transformer_layers=4,
            mlp_dims=(1024, 1024, 32 * 1024),
            quantized=self.config.get("adapter_quantized", False),
            quant_scheme=self.config.get("adapter_quant_scheme", None),
        )
        if self.config.get("adapter_quantized", False):
            if self.config.get("adapter_quant_scheme", None) == "fp8":
                model_name = "audio_adapter_fp8.safetensors"
            elif self.config.get("adapter_quant_scheme", None) == "int8":
                model_name = "audio_adapter_int8.safetensors"
            else:
                raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
        else:
            if isinstance(vae_encoder_out, list):
                vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
            model_name = "audio_adapter.safetensors"
        rank0_load_state_dict_from_path(audio_adapter, os.path.join(self.config["model_path"], model_name), strict=False)
        return audio_adapter.to(dtype=GET_DTYPE())
        return vae_encoder_out, clip_encoder_out
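
Note: a worked example of the patched-size arithmetic in run_image_encoder above. The stride and patch values below are assumptions chosen for illustration (typical Wan-style settings), not values read from this diff.

import numpy as np

# Assumed values for illustration.
vae_stride = (4, 8, 8)     # temporal, height, width stride of the VAE (assumed)
patch_size = (1, 2, 2)     # DiT patchify size (assumed)
target_height, target_width = 720, 1280
h, w = 1080, 1920          # hypothetical reference-image size

aspect_ratio = h / w
max_area = target_height * target_width
patched_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1])
patched_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2])

lat_h = patched_h * patch_size[1]   # latent height fed to the DiT  -> 90
lat_w = patched_w * patch_size[2]   # latent width                  -> 160
tgt_h = lat_h * vae_stride[1]       # pixel height for the resized reference image -> 720
tgt_w = lat_w * vae_stride[2]       # pixel width                                  -> 1280
print(patched_h, patched_w, lat_h, lat_w, tgt_h, tgt_w)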
    @ProfilingContext("Load models")
    def load_model(self):
        super().load_model()
        self.audio_encoder = self.load_audio_encoder()
        self.audio_adapter = self.load_audio_adapter()
        self.model.set_audio_adapter(self.audio_adapter)

    def set_target_shape(self):
        """Set target shape for generation"""
...
...
@@ -701,62 +569,6 @@ class WanAudioRunner(WanRunner): # type:ignore
        ret["target_shape"] = self.config.target_shape
        return ret
    def run_step(self):
        """Optimized pipeline with modular components"""
        self.initialize()
        assert self._audio_processor is not None
        assert self._audio_preprocess is not None
        self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
        with memory_efficient_inference():
            if self.config["use_prompt_enhancer"]:
                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
            self.inputs = self.prepare_inputs()
            # Re-initialize scheduler after image encoding sets correct dimensions
            self.init_scheduler()
            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
            # Re-create video generator with updated model/scheduler
            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)
        # Process audio
        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
        video_duration = self.config.get("video_duration", 5)
        target_fps = self.config.get("target_fps", 16)
        max_num_frames = self.config.get("target_video_length", 81)
        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
        # Segment audio
        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
        self._video_generator.total_segments = len(audio_segments)
        # Generate video segments
        prev_video = None
        torch.manual_seed(self.config.seed)
        # Process audio features
        audio_features = self._audio_preprocess(audio_segments[0].audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
        # Generate video segment
        with memory_efficient_inference():
            self._video_generator.generate_segment(
                self.inputs.copy(),  # Copy to avoid modifying original
                audio_features,
                prev_video=prev_video,
                prev_frame_length=5,
                segment_idx=0,
                total_steps=1,
            )
        # Final cleanup
        self.end_run()
@RUNNER_REGISTER("wan2.2_audio")
class Wan22AudioRunner(WanAudioRunner):
...
...
lightx2v/models/runners/wan/wan_runner.py
View file @ d8454a2b
...
...
@@ -225,12 +225,10 @@ class WanRunner(DefaultRunner):
    def run_image_encoder(self, first_frame, last_frame=None):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.image_encoder = self.load_image_encoder()
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
        if last_frame is None:
            clip_encoder_out = self.image_encoder.visual([first_frame[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
            clip_encoder_out = self.image_encoder.visual([first_frame]).squeeze(0).to(GET_DTYPE())
        else:
            last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
            clip_encoder_out = self.image_encoder.visual([first_frame[:, None, :, :].transpose(0, 1), last_frame[:, None, :, :].transpose(0, 1)]).squeeze(0).to(GET_DTYPE())
            clip_encoder_out = self.image_encoder.visual([first_frame, last_frame]).squeeze(0).to(GET_DTYPE())
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.image_encoder
            torch.cuda.empty_cache()
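
Note on the hunk above: the two adjacent image_encoder.visual(...) calls are the two sides of the diff. One wraps the (C, H, W) tensor produced by TF.to_tensor as (1, C, H, W) via first_frame[None] before passing it; the other passes the (C, H, W) tensor directly. A minimal shape check with a hypothetical placeholder image:

import torchvision.transforms.functional as TF
from PIL import Image

img = Image.new("RGB", (1280, 720))               # hypothetical placeholder image (W, H)
first_frame = TF.to_tensor(img).sub_(0.5).div_(0.5)

print(first_frame.shape)        # torch.Size([3, 720, 1280])
print(first_frame[None].shape)  # torch.Size([1, 3, 720, 1280]) -- batched form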
...
...
@@ -238,9 +236,7 @@ class WanRunner(DefaultRunner):
        return clip_encoder_out

    def run_vae_encoder(self, first_frame, last_frame=None):
        first_frame_size = first_frame.size
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
        h, w = first_frame.shape[1:]
        h, w = first_frame.shape[2:]
        aspect_ratio = h / w
        max_area = self.config.target_height * self.config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
...
...
@@ -260,8 +256,8 @@ class WanRunner(DefaultRunner):
            return vae_encode_out_list
        else:
            if last_frame is not None:
                last_frame_size = last_frame.size
                last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
                first_frame_size = first_frame.shape[2:]
                last_frame_size = last_frame.shape[2:]
                if first_frame_size != last_frame_size:
                    last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
                    last_frame_size = [
...
...
@@ -298,16 +294,16 @@ class WanRunner(DefaultRunner):
            if last_frame is not None:
                vae_input = torch.concat(
                    [
                        torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 2, h, w),
                        torch.nn.functional.interpolate(last_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.nn.functional.interpolate(last_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                    ],
                    dim=1,
                ).cuda()
            else:
                vae_input = torch.concat(
                    [
                        torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.nn.functional.interpolate(first_frame.cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
                        torch.zeros(3, self.config.target_video_length - 1, h, w),
                    ],
                    dim=1,
...
...
lightx2v/models/schedulers/wan/audio/scheduler.py
View file @ d8454a2b
import gc
import math

import numpy as np
import torch
from loguru import logger

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *


def unsqueeze_to_ndim(in_tensor, tgt_n_dim):
    if in_tensor.ndim > tgt_n_dim:
        warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
        return in_tensor
    if in_tensor.ndim < tgt_n_dim:
        in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
    return in_tensor


class EulerSchedulerTimestepFix(BaseScheduler):
    def __init__(self, config, **kwargs):
        # super().__init__(**kwargs)
        self.init_noise_sigma = 1.0
        self.config = config
        self.latents = None
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.num_train_timesteps = 1000
        self.step_index = None


class ConsistencyModelScheduler(WanScheduler):
    def __init__(self, config):
        super().__init__(config)

    def step_pre(self, step_index):
        self.step_index = step_index
...
...
@@ -37,12 +19,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
    def prepare(self, image_encoder_output=None):
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        elif self.config.task in ["i2v"]:
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
...
...
@@ -53,29 +29,13 @@ class EulerSchedulerTimestepFix(BaseScheduler):
        self.timesteps = self.sigmas * self.num_train_timesteps

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x_t_next = sample + (sigma_next - sigma) * model_output
        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next

    def reset(self):
...
...
@@ -83,13 +43,10 @@ class EulerSchedulerTimestepFix(BaseScheduler):
        gc.collect()
        torch.cuda.empty_cache()
class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next

    def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
        if in_tensor.ndim > tgt_n_dim:
            logger.warning(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
            return in_tensor
        if in_tensor.ndim < tgt_n_dim:
            in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
        return in_tensor
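
Note: the two step_post variants in this file apply different update rules to the same velocity-style model output (the algebra follows directly from the code above: a plain Euler move versus "predict x0, then re-noise to the next sigma"). A minimal standalone sketch of both updates, using dummy tensors and hypothetical sigma values:

import torch

def euler_step(sample, model_output, sigma, sigma_next):
    # Plain Euler step along the predicted velocity.
    return sample + (sigma_next - sigma) * model_output

def consistency_step(sample, model_output, sigma, sigma_next, generator=None):
    # Jump to the predicted clean sample, then re-noise to the next sigma level.
    x0 = sample - model_output * sigma
    noise = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=generator)
    return x0 * (1 - sigma_next) + sigma_next * noise

# Hypothetical tensors and sigmas for illustration.
g = torch.Generator().manual_seed(42)
x = torch.randn(1, 4, 2, 8, 8, generator=g)
v = torch.randn_like(x)
print(euler_step(x, v, sigma=0.8, sigma_next=0.6).shape)
print(consistency_step(x, v, sigma=0.8, sigma_next=0.6, generator=g).shape)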
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
View file @ d8454a2b
...
...
@@ -20,6 +20,7 @@ class WanScheduler4ChangingResolution:
        assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
        self.latents_list = []
        for i in range(len(self.config["resolution_rate"])):
            self.latents_list.append(
...
...
lightx2v/models/schedulers/wan/scheduler.py
View file @ d8454a2b
...
...
@@ -26,8 +26,6 @@ class WanScheduler(BaseScheduler):
    def prepare(self, image_encoder_output=None):
        if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
            self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
...
...
@@ -51,6 +49,7 @@ class WanScheduler(BaseScheduler):
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
        self.latents = torch.randn(
            target_shape[0],
            target_shape[1],
...
...
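
Note: with the generator now seeded inside prepare_latents (the added line above), repeated calls with the same config seed produce identical initial latents. A minimal standalone check, using a hypothetical shape and seed:

import torch

def make_latents(seed, shape, device="cpu"):
    # Mirrors the pattern above: a fresh, seeded generator per call.
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(*shape, dtype=torch.float32, device=device, generator=generator)

a = make_latents(42, (16, 21, 16, 16))
b = make_latents(42, (16, 21, 16, 16))
print(torch.equal(a, b))  # True: same seed, same initial noise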
tools/convert/quant_adapter.py
0 → 100644
View file @ d8454a2b
import safetensors
import torch
from safetensors.torch import save_file

from lightx2v.utils.quant_utils import FloatQuantizer

model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter.safetensors"
state_dict = {}
with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
    for key in f.keys():
        state_dict[key] = f.get_tensor(key)

new_state_dict = {}
new_model_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_fp8.safetensors"

for key in state_dict.keys():
    if key.startswith("ca") and ".to" in key and "weight" in key and "to_kv" not in key:
        print(key, state_dict[key].dtype)
        weight = state_dict[key].to(torch.float32).cuda()
        w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
        weight, weight_scale, _ = w_quantizer.real_quant_tensor(weight)
        weight = weight.to(torch.float8_e4m3fn)
        weight_scale = weight_scale.to(torch.float32)
        new_state_dict[key] = weight.cpu()
        new_state_dict[key + "_scale"] = weight_scale.cpu()

for key in state_dict.keys():
    if key not in new_state_dict.keys():
        new_state_dict[key] = state_dict[key]

save_file(new_state_dict, new_model_path)
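
Note: a small sanity-check sketch for the converted checkpoint written by this script. It dequantizes one fp8 weight back to float and compares it against the original tensor; the key selection, reshape, and paths are illustrative assumptions (the scale layout is assumed to be per output channel, as in the runtime fp8 linear buffers).

import safetensors
import torch

orig_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter.safetensors"
fp8_path = "/data/nvme0/models/Wan2.1-R2V721-Audio-14B-720P/audio_adapter_fp8.safetensors"

with safetensors.safe_open(fp8_path, framework="pt", device="cpu") as f:
    scale_key = [k for k in f.keys() if k.endswith("_scale")][0]
    weight_key = scale_key[: -len("_scale")]
    q_weight = f.get_tensor(weight_key)   # float8_e4m3fn
    scale = f.get_tensor(scale_key)       # float32 per-channel scale (assumed layout)

with safetensors.safe_open(orig_path, framework="pt", device="cpu") as f:
    ref = f.get_tensor(weight_key).to(torch.float32)

# Per-channel dequantization: weight ≈ q_weight * scale, with scale broadcast over rows.
deq = q_weight.to(torch.float32) * scale.reshape(-1, 1)
print((deq - ref).abs().max())  # expect a small quantization error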