xuwx1 / LightX2V · Commits

Commit 048be946, authored Jul 23, 2025 by gaclove
Parent: d048b178

feat: add adaptive resizing configuration and implement new resizing functions in WanAudioRunner

Showing 3 changed files with 153 additions and 83 deletions:

- configs/audio_driven/wan_i2v_audio.json (+15, -15)
- configs/audio_driven/wan_i2v_audio_adaptive_resize.json (+18, -0)
- lightx2v/models/runners/wan/wan_audio_runner.py (+120, -68)
configs/audio_driven/wan_i2v_audio.json

```diff
 {
     "infer_steps": 4,
     "target_fps": 16,
     "video_duration": 16,
     "audio_sr": 16000,
     "target_video_length": 81,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 720,
+    "target_width": 1280,
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false
 }
```

Only target_height and target_width change substantively, raising the default render target from 480×832 to 720×1280; the old and new file contents are otherwise identical, so the rest of the +15/-15 line count appears to be whitespace-only reformatting.
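The resolution bump is not free: a back-of-the-envelope check (plain arithmetic, nothing repo-specific) shows the per-frame pixel budget grows by roughly 2.3×, worth keeping in mind when reusing this config on smaller GPUs.

```python
# Per-frame pixel budget before and after this config change.
old_pixels = 480 * 832   # 399,360
new_pixels = 720 * 1280  # 921,600
print(f"{new_pixels / old_pixels:.2f}x")  # ~2.31x more pixels per frame
```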
configs/audio_driven/wan_i2v_audio_adaptive_resize.json (new file, mode 100644)

```json
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 16,
    "audio_sr": 16000,
    "target_video_length": 81,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "adaptive_resize": true
}
```
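The new config matches wan_i2v_audio.json except for the trailing "adaptive_resize": true, which is the flag the runner change below keys on via config.get("adaptive_resize", False). A minimal sanity check, assuming the file is read as plain JSON:

```python
import json

with open("configs/audio_driven/wan_i2v_audio_adaptive_resize.json") as f:
    cfg = json.load(f)

# Mirrors the lookup in wan_audio_runner.py: a missing flag defaults to
# False, so existing configs keep the fixed-resolution code path.
assert cfg.get("adaptive_resize", False) is True
```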
lightx2v/models/runners/wan/wan_audio_runner.py

```diff
@@ -51,6 +51,90 @@ def memory_efficient_inference():
     gc.collect()
```
```diff
+def optimize_latent_size_with_sp(lat_h, lat_w, sp_size, patch_size):
+    patched_h, patched_w = lat_h // patch_size[0], lat_w // patch_size[1]
+    if (patched_h * patched_w) % sp_size == 0:
+        return lat_h, lat_w
+    else:
+        h_ratio, w_ratio = 1, 1
+        h_noevenly_n, w_noevenly_n = 0, 0
+        h_backup, w_backup = patched_h, patched_w
+        while sp_size // 2 != 1:
+            if h_backup % 2 == 0:
+                h_backup //= 2
+                h_ratio *= 2
+            elif w_backup % 2 == 0:
+                w_backup //= 2
+                w_ratio *= 2
+            elif h_noevenly_n <= w_noevenly_n:
+                h_backup //= 2
+                h_ratio *= 2
+                h_noevenly_n += 1
+            else:
+                w_backup //= 2
+                w_ratio *= 2
+                w_noevenly_n += 1
+            sp_size //= 2
+        new_lat_h = lat_h // h_ratio * h_ratio
+        new_lat_w = lat_w // w_ratio * w_ratio
+        return new_lat_h, new_lat_w
```
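optimize_latent_size_with_sp checks whether the patched token grid (latent size divided by patch size) splits evenly across sp_size sequence-parallel ranks; if not, it repeatedly halves whichever patched dimension is still even (falling back to the less-penalized odd one) and rounds the latent dims down to multiples of the accumulated ratios. A demonstrative trace, assuming Wan-style 2×2 spatial patches (the patch geometry is an assumption here):

```python
# sp_size=1 takes the fast path and returns the input unchanged;
# this is exactly how the runner below calls it.
print(optimize_latent_size_with_sp(90, 160, 1, (2, 2)))  # (90, 160)

# With sp_size=8, each patched dim (22, 38) gets halved once, so both
# latent dims are rounded down to the nearest multiple of 2.
print(optimize_latent_size_with_sp(45, 77, 8, (2, 2)))   # (44, 76)
```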
```diff
+def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
+    tgt_ar = tgt_h / tgt_w
+    ori_ar = ori_h / ori_w
+    if abs(ori_ar - tgt_ar) < 0.01:
+        return 0, ori_h, 0, ori_w
+    if ori_ar > tgt_ar:
+        crop_h = int(tgt_ar * ori_w)
+        y0 = (ori_h - crop_h) // 2
+        y1 = y0 + crop_h
+        return y0, y1, 0, ori_w
+    else:
+        crop_w = int(ori_h / tgt_ar)
+        x0 = (ori_w - crop_w) // 2
+        x1 = x0 + crop_w
+        return 0, ori_h, x0, x1
```
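get_crop_bbox returns the (y0, y1, x0, x1) bounds of the largest centered crop whose aspect ratio matches the target; source ratios within 0.01 of the target short-circuit to the full frame. Two worked calls:

```python
# Ratios already match (1080/1920 == 720/1280 == 0.5625): full frame.
print(get_crop_bbox(1080, 1920, 720, 1280))  # (0, 1080, 0, 1920)

# Square source, wider target: keep full width, center-crop the height.
print(get_crop_bbox(1080, 1080, 480, 832))   # (228, 851, 0, 1080)
```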
```diff
+def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
+    """
+    frames: (T, C, H, W)
+    size: (H, W)
+    """
+    ori_h, ori_w = frames.shape[2:]
+    h, w = size
+    y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w)
+    cropped_frames = frames[:, :, y0:y1, x0:x1]
+    resized_frames = resize(cropped_frames, size, InterpolationMode.BICUBIC, antialias=True)
+    return resized_frames
```
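isotropic_crop_resize center-crops to the target aspect ratio and then resizes, so nothing gets anisotropically squashed. The hunk does not show the file's imports; the call signature matches torchvision.transforms.functional.resize, so a standalone usage sketch under that assumption:

```python
import torch
from torchvision.transforms.functional import InterpolationMode, resize  # assumed imports

frames = torch.rand(16, 3, 1080, 1920)            # (T, C, H, W)
out = isotropic_crop_resize(frames, (720, 1280))  # same aspect ratio: crop keeps the full frame, then bicubic resize
print(out.shape)                                  # torch.Size([16, 3, 720, 1280])
```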
```diff
+def adaptive_resize(img):
+    bucket_config = {
+        0.667: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), np.array([0.2, 0.5, 0.3])),
+        1.0: (np.array([[480, 480], [576, 576], [704, 704], [960, 960]], dtype=np.int64), np.array([0.1, 0.1, 0.5, 0.3])),
+        1.5: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64)[:, ::-1], np.array([0.2, 0.5, 0.3])),
+    }
+    ori_height = img.shape[-2]
+    ori_weight = img.shape[-1]
+    ori_ratio = ori_height / ori_weight
+    aspect_ratios = np.array(np.array(list(bucket_config.keys())))
+    closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
+    closet_ratio = aspect_ratios[closet_aspect_idx]
+    if ori_ratio < 1.0:
+        target_h, target_w = 480, 832
+    elif ori_ratio == 1.0:
+        target_h, target_w = 480, 480
+    else:
+        target_h, target_w = 832, 480
+    for resolution in bucket_config[closet_ratio][0]:
+        if ori_height * ori_weight >= resolution[0] * resolution[1]:
+            target_h, target_w = resolution
+    cropped_img = isotropic_crop_resize(img, (target_h, target_w))
+    return cropped_img, target_h, target_w
```
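adaptive_resize snaps an arbitrary reference image to a resolution bucket: buckets are keyed by height/width ratio (0.667 landscape, 1.0 square, 1.5 portrait), the closest key is chosen, and within that bucket the largest resolution whose pixel area the source still covers wins. (The second array per bucket looks like sampling weights and is unused here, and ori_weight reads like a typo for ori_width; both are kept as committed.) A worked call, assuming the module's numpy/torchvision imports are in scope:

```python
import torch

img = torch.rand(1, 3, 1080, 1920)  # h/w = 0.5625 -> closest bucket key is 0.667
out, h, w = adaptive_resize(img)
# 1080*1920 pixels cover every landscape bucket, so the largest wins.
print(h, w, out.shape)              # 720 1280 torch.Size([1, 3, 720, 1280])
```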
```diff
 @dataclass
 class AudioSegment:
     """Data class for audio segment information"""
```
```diff
@@ -300,61 +384,6 @@ class VideoGenerator:
         return gen_video
```

The 55 lines deleted in this hunk are verbatim copies of get_crop_bbox, isotropic_crop_resize, and adaptive_resize; the commit moves those definitions to the top of the module (shown above) rather than changing them, so they are not repeated here.
```diff
 @RUNNER_REGISTER("wan2.1_audio")
 class WanAudioRunner(WanRunner):
     def __init__(self, config):
```
```diff
@@ -604,19 +633,42 @@ class WanAudioRunner(WanRunner):
         ref_img = rearrange(ref_img, "H W C -> 1 C H W")
         ref_img = ref_img[:, :3]
-        # Resize and crop image
-        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
-        config.tgt_h = tgt_h
-        config.tgt_w = tgt_w
-        clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
-        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
-        lat_h, lat_w = tgt_h // 8, tgt_w // 8
-        config.lat_h = lat_h
-        config.lat_w = lat_w
-        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
-        if isinstance(vae_encode_out, list):
-            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
+        if config.get("adaptive_resize", False):
+            # Use adaptive_resize to modify aspect ratio
+            cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
+            config.tgt_h = tgt_h
+            config.tgt_w = tgt_w
+            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
+            lat_h, lat_w = tgt_h // 8, tgt_w // 8
+            config.lat_h = lat_h
+            config.lat_w = lat_w
+            vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
+            if isinstance(vae_encode_out, list):
+                vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
+        else:
+            h, w = ref_img.shape[2:]
+            aspect_ratio = h / w
+            max_area = config.target_height * config.target_width
+            lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
+            lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
+            lat_h, lat_w = optimize_latent_size_with_sp(lat_h, lat_w, 1, config.patch_size[1:])
+            config.lat_h, config.lat_w = lat_h, lat_w
+            config.tgt_h = lat_h * config.vae_stride[1]
+            config.tgt_w = lat_w * config.vae_stride[2]
+            # Resize image to target size
+            cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
+            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            # Prepare for VAE encoding
+            cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
+            vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
+            if isinstance(vae_encode_out, list):
+                vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
         return vae_encode_out, clip_encoder_out
```
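Before this commit, adaptive_resize ran unconditionally; now it is opt-in, and the new else branch becomes the default. That branch keeps the configured pixel area (target_height * target_width) constant while matching the reference image's aspect ratio, then snaps the latent grid to VAE-stride and patch-size multiples (the sp_size=1 call makes optimize_latent_size_with_sp a no-op for now). Worked numbers, assuming Wan's usual vae_stride = (4, 8, 8) and patch_size = (1, 2, 2), which this hunk does not show:

```python
import numpy as np

# Assumed model geometry (not visible in this hunk).
vae_stride = (4, 8, 8)
patch_size = (1, 2, 2)

h, w = 1080, 1920        # hypothetical reference image
aspect_ratio = h / w     # 0.5625
max_area = 720 * 1280    # target_height * target_width from the config

lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2])
print(lat_h, lat_w)                                  # 90 160
print(lat_h * vae_stride[1], lat_w * vae_stride[2])  # tgt_h, tgt_w = 720, 1280
```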