xuwx1 / LightX2V / Commits

Commit 6060ff4f, authored Jul 03, 2025 by wangshankun
Support: radial attention
Parent: b2147c40

Showing 15 changed files with 297 additions and 48 deletions (+297 -48)
configs/wan_i2v_audio.json                                     +8   -10
lightx2v/attentions/__init__.py                                +3   -0
lightx2v/attentions/common/radial_attn.py                      +199 -0   (new file)
lightx2v/common/ops/attn/attn_weight.py                        +27  -2
lightx2v/infer.py                                              +3   -1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py         +1   -1
lightx2v/models/networks/wan/audio_adapter.py                  +0   -2
lightx2v/models/networks/wan/audio_model.py                    +11  -0
lightx2v/models/networks/wan/infer/transformer_infer.py        +4   -2
lightx2v/models/networks/wan/infer/utils.py                    +1   -1
lightx2v/models/networks/wan/lora_adapter.py                   +1   -0
lightx2v/models/networks/wan/weights/transformer_weights.py    +3   -3
lightx2v/models/runners/wan/wan_audio_runner.py                +32  -26
lightx2v/models/schedulers/wan/scheduler.py                    +3   -0
scripts/run_wan_i2v_audio.sh                                   +1   -0
configs/wan_i2v_audio.json  (+8 -10)

 {
-    "infer_steps": 8,
+    "infer_steps": 5,
+    "target_fps": 16,
+    "video_duration": 12,
+    "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
-    "attention_type": "flash_attn3",
+    "self_attn_1_type": "radial_attn",
+    "cross_attn_1_type": "flash_attn3",
+    "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
-    "cpu_offload": false,
-    "feature_caching": "Tea",
-    "coefficients": [
-        [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
-        [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
-    ],
-    "use_ret_steps": true,
-    "teacache_thresh": 0.12
+    "cpu_offload": false
 }
lightx2v/attentions/__init__.py  (+3 -0)

...
@@ -2,6 +2,7 @@ from lightx2v.attentions.common.torch_sdpa import torch_sdpa
 from lightx2v.attentions.common.flash_attn2 import flash_attn2
 from lightx2v.attentions.common.flash_attn3 import flash_attn3
 from lightx2v.attentions.common.sage_attn2 import sage_attn2
+from lightx2v.attentions.common.radial_attn import radial_attn


 def attention(attention_type="flash_attn2", *args, **kwargs):
...
@@ -13,5 +14,7 @@ def attention(attention_type="flash_attn2", *args, **kwargs):
         return flash_attn3(*args, **kwargs)
     elif attention_type == "sage_attn2":
         return sage_attn2(*args, **kwargs)
+    elif attention_type == "radial_attn":
+        return radial_attn(*args, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
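For context, a minimal sketch (not part of this commit) of how the new branch can be exercised through this dispatcher. It assumes a CUDA device, the flashinfer package that radial_attn.py imports, and the MaskMap helper defined in the new file below; the tensor sizes and dtype are illustrative only.

import torch
from lightx2v.attentions import attention
from lightx2v.attentions.common.radial_attn import MaskMap

# illustrative sizes: 4 latent frames x 256 tokens, 12 heads, head_dim 128
q = torch.randn(1024, 12, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask_map = MaskMap(video_token_num=1024, num_frame=4)

# dispatches to radial_attn(), which builds the block-sparse mask and runs the flashinfer kernel
out = attention("radial_attn", q, k, v, mask_map=mask_map, model_cls="wan")
print(out.shape)  # (1024, 12, 128): any padding added for the kernel is stripped off again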
lightx2v/attentions/common/radial_attn.py  (new file, 0 → 100644, +199)

import torch
import flashinfer


###
### Code from radial-attention
### https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py#L150
###
def radial_attn(
    query,
    key,
    value,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    mask_map=None,
    sparsity_type="radial",
    block_size=128,
    decay_factor=1,
    model_cls="wan",
):
    orig_seqlen, num_head, hidden_dim = query.shape
    query = pad_qkv(query, block_size=block_size)
    key = pad_qkv(key, block_size=block_size)
    value = pad_qkv(value, block_size=block_size)

    mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_cls) if mask_map else None
    seqlen = query.shape[0]
    workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
    bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
        workspace_buffer,
        backend="fa2",
    )

    indptr = get_indptr_from_mask(mask, query)
    indices = get_indices_from_mask(mask, query)
    bsr_wrapper.plan(
        indptr=indptr,
        indices=indices,
        M=seqlen,
        N=seqlen,
        R=block_size,
        C=block_size,
        num_qo_heads=num_head,
        num_kv_heads=num_head,
        head_dim=hidden_dim,
        q_data_type=query.dtype,
        kv_data_type=key.dtype,
        o_data_type=query.dtype,
        use_fp16_qk_reduction=True,
    )

    o = bsr_wrapper.run(query, key, value)
    return o[:orig_seqlen, :, :]
def get_indptr_from_mask(mask, query):
    # query shows the device of the indptr
    # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension,
    # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
    # The first element is always 0, and the last element is the number of blocks in the row dimension.
    # The rest of the elements are the number of blocks in each row.
    # the mask is already a block sparse mask
    indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
    indptr[0] = 0
    row_counts = mask.sum(dim=1).flatten()  # Ensure 1D output [num_blocks_row]
    indptr[1:] = torch.cumsum(row_counts, dim=0)
    return indptr
def get_indices_from_mask(mask, query):
    # indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension,
    # shape `(nnz,),` where `nnz` is the number of non-zero blocks.
    # The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension.
    nonzero_indices = torch.nonzero(mask)
    indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
    return indices
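As a quick illustration (not part of the commit) of the BSR layout these two helpers produce; pure torch, no GPU needed, and the mask values are made up:

import torch

# a 3x3 block-level mask: row 0 keeps blocks 0 and 2, row 1 keeps block 1, row 2 keeps all three
mask = torch.tensor([[1, 0, 1],
                     [0, 1, 0],
                     [1, 1, 1]], dtype=torch.bool)

indptr = torch.zeros(4, dtype=torch.int32)
indptr[1:] = torch.cumsum(mask.sum(dim=1), dim=0)     # -> [0, 2, 3, 6]: row i owns indices[indptr[i]:indptr[i+1]]
indices = torch.nonzero(mask)[:, 1].to(torch.int32)   # -> [0, 2, 1, 0, 1, 2]: column ids of the kept blocks
print(indptr.tolist(), indices.tolist())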
def shrinkMaskStrict(mask, block_size=128):
    seqlen = mask.shape[0]
    block_num = seqlen // block_size
    mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
    col_densities = mask.sum(dim=1) / block_size
    # we want the minimum non-zero column density in the block
    non_zero_densities = col_densities > 0
    high_density_cols = col_densities > 1 / 3
    frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
    block_mask = frac_high_density_cols > 0.6
    block_mask[0:0] = True
    block_mask[-1:-1] = True
    return block_mask
def pad_qkv(input_tensor, block_size=128):
    """
    Pad the input tensor to be a multiple of the block size.
    input shape: (seqlen, num_heads, hidden_dim)
    """
    seqlen, num_heads, hidden_dim = input_tensor.shape
    # Calculate the necessary padding
    padding_length = (block_size - (seqlen % block_size)) % block_size
    # Create a padded tensor with zeros
    padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
    # Copy the original tensor into the padded tensor
    padded_tensor[:seqlen, :, :] = input_tensor
    return padded_tensor
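For concreteness, the padding formula evaluated on an illustrative sequence length (not from the commit):

seqlen, block_size = 34320, 128                                  # e.g. 34320 video tokens, not a multiple of 128
padding_length = (block_size - (seqlen % block_size)) % block_size
print(padding_length, seqlen + padding_length)                   # 112 34432, i.e. 269 full blocks of 128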
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    group = dist.bit_length()
    threshold = 128  # hardcoded threshold for now, which is equal to block-size
    decay_length = 2 ** token_per_frame.bit_length() / 2**group
    if decay_length >= threshold:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)

    split_factor = int(threshold / decay_length)
    modular = dist % split_factor
    if modular == 0:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
    else:
        return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    if model_type == "wan":
        if dist < 1:
            return token_per_frame
        if dist == 1:
            return token_per_frame // 2
    elif model_type == "hunyuan":
        if dist <= 1:
            return token_per_frame
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    group = dist.bit_length()
    decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
    threshold = block_size
    if decay_length >= threshold:
        return decay_length
    else:
        return threshold
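To get a feel for the decay schedule, a small sketch (not part of the commit) that evaluates the function above; the per-frame token count is illustrative:

from lightx2v.attentions.common.radial_attn import get_window_width

token_per_frame = 1560  # e.g. one 480x832 Wan latent frame after 2x2 patching
for dist in (0, 1, 2, 4, 8, 32):
    w = get_window_width(0, dist, token_per_frame, "radial", num_frame=22, decay_factor=1, model_type="wan")
    print(dist, w)  # 0 -> 1560 (full frame), 1 -> 780, then halving: 2 -> 512.0, 4 -> 256.0, 8 -> 128.0, 32 -> 128 (block-size floor)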
def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
    """
    A more memory friendly version, we generate the attention mask of each frame pair at a time,
    shrinks it, and stores it into the final result
    """
    final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
    token_per_frame = video_token_num // num_frame
    video_text_border = video_token_num // block_size

    col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
    row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
    final_log_mask[video_text_border:] = True
    final_log_mask[:, video_text_border:] = True
    for i in range(num_frame):
        for j in range(num_frame):
            local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            if j == 0:  # this is attention sink
                local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            else:
                window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
                local_mask = torch.abs(col_indices - row_indices) <= window_width
                split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
                local_mask = torch.logical_and(local_mask, split_mask)

            remainder_row = (i * token_per_frame) % block_size
            remainder_col = (j * token_per_frame) % block_size
            # get the padded size
            all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
            all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
            padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
            padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
            # shrink the mask
            block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
            # set the block mask to the final log mask
            block_row_start = (i * token_per_frame) // block_size
            block_col_start = (j * token_per_frame) // block_size
            block_row_end = block_row_start + block_mask.shape[0]
            block_col_end = block_col_start + block_mask.shape[1]
            final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(
                final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask
            )
    print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
    return final_log_mask
class MaskMap:
    def __init__(self, video_token_num=79200, num_frame=22):
        self.video_token_num = video_token_num
        self.num_frame = num_frame
        self.log_mask = None

    def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
        log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
        if self.log_mask is None:
            self.log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size)
        block_bound = self.video_token_num // block_size
        log_mask[:block_bound, :block_bound] = self.log_mask[:block_bound, :block_bound]
        return log_mask
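A minimal sketch (not part of the commit) of the mask-generation path on its own. It only needs the module to be importable (the top-level `import flashinfer` must succeed); the mask itself is built with plain torch ops on CPU, and the sizes are toy values:

import torch
from lightx2v.attentions.common.radial_attn import MaskMap

# toy sizes: 4 "frames" of 256 tokens each, sequence length 1024 (a multiple of the 128 block size)
mask_map = MaskMap(video_token_num=1024, num_frame=4)
query = torch.zeros(1024, 12, 128)  # (seqlen, num_heads, head_dim); only shape/device are used for the mask
block_mask = mask_map.queryLogMask(query, "radial", block_size=128, model_type="wan")
print(block_mask.shape)               # (8, 8): one entry per 128x128 attention block
print(1 - block_mask.float().mean())  # fraction of attention blocks that the sparse kernel can skip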
lightx2v/common/ops/attn/attn_weight.py  (+27 -2)

...
@@ -37,6 +37,9 @@ else:
     sageattn = None

+from lightx2v.attentions.common.radial_attn import radial_attn
+
+
 class AttnWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name):
         self.weight_name = weight_name
...
@@ -70,7 +73,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
         x = flash_attn_varlen_func(
             q,
             k,
...
@@ -88,7 +91,7 @@ class FlashAttn3Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
         x = flash_attn_varlen_func_v3(
             q,
             k,
...
@@ -101,6 +104,28 @@ class FlashAttn3Weight(AttnWeightTemplate):
         return x


+@ATTN_WEIGHT_REGISTER("radial_attn")
+class RadialAttnWeight(AttnWeightTemplate):
+    def __init__(self):
+        self.config = {}
+
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
+        assert len(q.shape) == 3
+        x = radial_attn(
+            q,
+            k,
+            v,
+            mask_map=mask_map,
+            sparsity_type=sparsity_type,
+            block_size=block_size,
+            model_cls=model_cls[:3],  # Use first 3 characters to match "wan", "wan2", etc.
+            decay_factor=decay_factor,
+        )
+        x = x.view(max_seqlen_q, -1)
+        return x
+
+
 @ATTN_WEIGHT_REGISTER("sage_attn2")
 class SageAttn2Weight(AttnWeightTemplate):
     def __init__(self):
...
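A small sketch (not part of the commit) of how the new registry entry is intended to be resolved; it assumes ATTN_WEIGHT_REGISTER is importable from this module and mirrors the per-key lookup added in transformer_weights.py further below. The config dict is illustrative:

from lightx2v.common.ops.attn.attn_weight import ATTN_WEIGHT_REGISTER

config = {"self_attn_1_type": "radial_attn", "cross_attn_1_type": "flash_attn3"}
self_attn_1 = ATTN_WEIGHT_REGISTER[config["self_attn_1_type"]]()    # -> RadialAttnWeight()
cross_attn_1 = ATTN_WEIGHT_REGISTER[config["cross_attn_1_type"]]()  # -> FlashAttn3Weight()
# self_attn_1.apply(q, k, v, max_seqlen_q=..., mask_map=...) then routes into radial_attn()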
lightx2v/infer.py  (+3 -1)

...
@@ -42,7 +42,9 @@ def init_runner(config):
 async def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
...
lightx2v/models/input_encoders/hf/xlm_roberta/model.py  (+1 -1)

...
@@ -449,7 +449,7 @@ class WanVideoIPHandler:
         but Wan2.1 official use no_crop resize by default
         so I don't use CLIPImageProcessor
         """
-        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, subfolder="image_encoder", torch_dtype=dtype)
+        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
         logger.info(f"Using image encoder {model_name} from {repo_or_path}")
         image_encoder.requires_grad_(require_grad)
         if mode == "eval":
...
lightx2v/models/networks/wan/audio_adapter.py  (+0 -2)

...
@@ -7,11 +7,9 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
 from loguru import logger
-import pdb
 import os
 import safetensors
 from typing import List, Optional, Tuple, Union
...
lightx2v/models/networks/wan/audio_model.py  (+11 -0)

...
@@ -19,6 +19,8 @@ from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
+from lightx2v.attentions.common.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.transformer_infer import (
     WanTransformerInfer,
 )
...
@@ -51,6 +53,15 @@ class WanAudioModel(WanModel):
         self.pre_weight.to_cuda()
         self.post_weight.to_cuda()

+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            num_frame = c + 1  # for r2v
+            video_token_num = num_frame * (h // 2) * (w // 2)
+            from loguru import logger
+            logger.info(f"video_token_num: {video_token_num}, num_frame: {num_frame}")
+            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
+
         embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
...
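For a sense of scale, the same arithmetic evaluated on an assumed latent shape (not from the commit): for the 480x832, 81-frame config above, Wan2.1's VAE typically yields latents of shape (16, 21, 60, 104), i.e. 16 channels, 21 latent frames, 60x104 spatial.

_, c, h, w = (16, 21, 60, 104)                      # assumed latents shape for 480x832, 81 frames
num_frame = c + 1                                   # 22: one extra frame for the r2v reference
video_token_num = num_frame * (h // 2) * (w // 2)   # 22 * 30 * 52 = 34320 video tokens after 2x2 patching
print(num_frame, video_token_num)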
lightx2v/models/networks/wan/infer/transformer_infer.py  (+4 -2)

...
@@ -7,7 +7,6 @@ from lightx2v.common.offload.manager import (
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
 from loguru import logger
-import pdb
 import os
...
@@ -24,6 +23,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.parallel_attention = None
         self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
+        self.mask_map = None
         if self.config["cpu_offload"]:
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
...
@@ -290,7 +291,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if not self.parallel_attention:
             if self.config.get("audio_sr", False):
-                freqs_i = compute_freqs_audio(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
             else:
                 freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         else:
...
@@ -321,6 +322,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 max_seqlen_q=q.size(0),
                 max_seqlen_kv=k.size(0),
                 model_cls=self.config["model_cls"],
+                mask_map=self.mask_map,
             )
         else:
             attn_out = self.parallel_attention(
...
lightx2v/models/networks/wan/infer/utils.py  (+1 -1)

...
@@ -23,7 +23,7 @@ def compute_freqs(c, grid_sizes, freqs):
 def compute_freqs_audio(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
-    f = f + 1
+    f = f + 1  ##for r2v add 1 channel
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
...
lightx2v/models/networks/wan/lora_adapter.py  (+1 -0)

...
@@ -50,6 +50,7 @@ class WanLoraWrapper:
         self.model._init_weights(weight_dict)
         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
+        del lora_weights  # delete to save GPU memory
         return True

     @torch.no_grad()
...
lightx2v/models/networks/wan/weights/transformer_weights.py  (+3 -3)

...
@@ -184,7 +184,7 @@ class WanSelfAttention(WeightModule):
                 sparge_ckpt = torch.load(self.config["sparge_ckpt"])
                 self.self_attn_1.load(sparge_ckpt)
             else:
-                self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+                self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
         if self.quant_method in ["smoothquant", "awq"]:
             self.add_module(
                 "smooth_norm1_weight",
...
@@ -275,7 +275,7 @@ class WanCrossAttention(WeightModule):
                 self.lazy_load_file,
             ),
         )
-        self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
         if self.config.task == "i2v":
             self.add_module(
...
@@ -304,7 +304,7 @@ class WanCrossAttention(WeightModule):
                 self.lazy_load_file,
             ),
         )
-        self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]())


 class WanFFN(WeightModule):
...
lightx2v/models/runners/wan/wan_audio_runner.py  (+32 -26)

...
@@ -31,7 +31,6 @@ from torchvision.transforms.functional import resize
 import subprocess
 import warnings
 from typing import Optional, Tuple, Union
-import pdb


 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
...
@@ -210,7 -209,6 @@ def generate_unique_path(path):
 def save_to_video(gen_lvideo, out_path, target_fps):
-    print(gen_lvideo.shape)
     gen_lvideo = rearrange(gen_lvideo, "B C T H W -> B T H W C")
     gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
     gen_lvideo = gen_lvideo[..., ::-1].copy()
...
@@ -219,13 +217,13 @@ def save_to_video(gen_lvideo, out_path, target_fps):
 def save_audio(
-    audio_array: str,
+    audio_array,
     audio_name: str,
     video_name: str = None,
     sr: int = 16000,
 ):
     logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
-    ta.save(
-        audio_name,
-        torch.tensor(audio_array[None]),
+    if not os.path.exists(audio_name):
+        ta.save(
+            audio_name,
+            torch.tensor(audio_array[None]),
...
@@ -233,7 +231,15 @@ def save_audio(
     )
     out_video = f"{video_name[:-4]}_with_audio.mp4"
     # generate_unique_path(out_path)
+    # make sure the parent directory exists
+    parent_dir = os.path.dirname(out_video)
+    if parent_dir and not os.path.exists(parent_dir):
+        os.makedirs(parent_dir, exist_ok=True)
+
+    # if the output video already exists, delete it first
+    if os.path.exists(out_video):
+        os.remove(out_video)
     cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
     subprocess.call(cmd, shell=True)
...
@@ -246,6 +252,9 @@ class WanAudioRunner(WanRunner):
     def load_audio_models(self):
         ## audio feature extractor
         self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+
+        ## adapter for audio-driven video generation
+        audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
         audio_adaper = AudioAdapter.from_transformer(
             self.model,
             audio_feature_dim=1024,
...
@@ -253,11 +262,11 @@ class WanAudioRunner(WanRunner):
             time_freq_dim=256,
             projection_transformer_layers=4,
         )
-        load_path = "/mnt/aigc/zoemodels/Zoetrained/vigendit/audio_driven/audio_adapter/audio_adapter_V1_0507_bf16.safetensors"
-        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, load_path, strict=False)
+        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, audio_adapter_path, strict=False)

+        ## audio feature encoder
         device = self.model.device
-        audio_encoder_repo = "/mnt/aigc/zoemodels/models--TencentGameMate--chinese-hubert-large/snapshots/90cb660492214f687e60f5ca509b20edae6e75bd"
+        audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
         audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
         return audio_adapter_pipe
...
@@ -275,9 +284,8 @@ class WanAudioRunner(WanRunner):
         return base_model

     def load_image_encoder(self):
+        clip_model_dir = self.config["model_path"] + "/image_encoder"
         image_encoder = WanVideoIPHandler(
-            "CLIPModel", repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers", require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16
+            "CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16
         )
         return image_encoder
...
@@ -325,6 +333,7 @@ class WanAudioRunner(WanRunner):
         self.set_target_shape()
         self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
+        del self.image_encoder  # drop the reference CLIP model, it is only used once
         gc.collect()
         torch.cuda.empty_cache()
...
@@ -360,15 +369,15 @@ class WanAudioRunner(WanRunner):
         self.inputs["audio_adapter_pipe"] = self.load_audio_models()

         # process audio
-        audio_sr = 16000
-        max_num_frames = 81  # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
+        audio_sr = self.config.get("audio_sr", 16000)
+        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
         target_fps = self.config.get("target_fps", 16)  # frame rate used for audio/video sync
-        video_duration = self.config.get("video_duration", 8)  # desired output video duration
+        video_duration = self.config.get("video_duration", 5)  # desired output video duration
         audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
         audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
         prev_frame_length = 5
         prev_token_length = (prev_frame_length - 1) // 4 + 1
-        max_num_audio_length = int((max_num_frames + 1) / target_fps * 16000)
+        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
         interval_num = 1
         # expected_frames
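A quick check of the segment arithmetic with the defaults above (16 kHz audio, 16 fps, 81-frame segments); this just evaluates the formulas from the hunk and is not part of the commit:

audio_sr, target_fps, max_num_frames = 16000, 16, 81
prev_frame_length = 5
prev_token_length = (prev_frame_length - 1) // 4 + 1                      # 2 latent tokens carried between segments
max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)  # 82000 audio samples per 81-frame segment
print(prev_token_length, max_num_audio_length)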
...
@@ -463,13 +472,10 @@ class WanAudioRunner(WanRunner):
             latents = self.model.scheduler.latents
             generator = self.model.scheduler.generator
             gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-            # gen_img = vae_handler.decode(xt.to(vae_dtype))
-            # B, C, T, H, W
             gen_video = torch.clamp(gen_video, -1, 1)
             start_frame = 0 if idx == 0 else prev_frame_length
             start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
-            print(f"----{idx}, {gen_video[:, :, start_frame:].shape}")
             if res_frame_num > 5 and idx == interval_num - 1:
                 gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
...
@@ -482,7 +488,7 @@ class WanAudioRunner(WanRunner):
         gen_lvideo = torch.cat(gen_video_list, dim=2).float()
         merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
-        out_path = os.path.join("./", "video_merge.mp4")
+        out_path = self.config.save_video_path
         audio_file = os.path.join("./", "audio_merge.wav")
         save_to_video(gen_lvideo, out_path, target_fps)
         save_audio(merge_audio, audio_file, out_path)
...
@@ -501,5 +507,5 @@ class WanAudioRunner(WanRunner):
         self.run()
         self.end_run()
-        torch.cuda.empty_cache()
         gc.collect()
+        torch.cuda.empty_cache()
...
lightx2v/models/schedulers/wan/scheduler.py  (+3 -0)

 import math
 import numpy as np
 import torch
+import gc
 from typing import List, Optional, Tuple, Union
 from lightx2v.models.schedulers.scheduler import BaseScheduler
...
@@ -123,6 +124,8 @@ class WanScheduler(BaseScheduler):
         self.this_order = None
         self.lower_order_nums = 0
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+        gc.collect()
+        torch.cuda.empty_cache()

     def multistep_uni_p_bh_update(
         self,
...
scripts/run_wan_i2v_audio.sh  (+1 -0)

...
@@ -27,6 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
+export DTYPE=BF16

 python -m lightx2v.infer \
     --model_cls wan2.1_audio \
...