Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
aa627f77
Unverified
Commit
aa627f77
authored
Dec 03, 2025
by
Musisoul
Committed by
GitHub
Dec 03, 2025
Browse files
Fix vae parallel bug (#548)
parent
a1a1a8c0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
9 deletions
+73
-9
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+70
-7
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+3
-2
No files found.
lightx2v/models/runners/wan/wan_runner.py
View file @
aa627f77
...
@@ -271,6 +271,58 @@ class WanRunner(DefaultRunner):
...
@@ -271,6 +271,58 @@ class WanRunner(DefaultRunner):
gc
.
collect
()
gc
.
collect
()
return
clip_encoder_out
return
clip_encoder_out
def
_adjust_latent_for_grid_splitting
(
self
,
latent_h
,
latent_w
,
world_size
):
"""
Adjust latent dimensions for optimal 2D grid splitting.
Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1.
"""
world_size_h
,
world_size_w
=
1
,
1
if
world_size
<=
1
:
return
latent_h
,
latent_w
,
world_size_h
,
world_size_w
# Define priority grids for different world sizes
priority_grids
=
[]
if
world_size
==
8
:
# For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1
priority_grids
=
[(
2
,
4
),
(
4
,
2
),
(
1
,
8
),
(
8
,
1
)]
elif
world_size
==
4
:
priority_grids
=
[(
2
,
2
),
(
1
,
4
),
(
4
,
1
)]
elif
world_size
==
2
:
priority_grids
=
[(
1
,
2
),
(
2
,
1
)]
else
:
# For other sizes, try factor pairs
for
h
in
range
(
1
,
int
(
np
.
sqrt
(
world_size
))
+
1
):
if
world_size
%
h
==
0
:
w
=
world_size
//
h
priority_grids
.
append
((
h
,
w
))
# Try priority grids first
for
world_size_h
,
world_size_w
in
priority_grids
:
if
latent_h
%
world_size_h
==
0
and
latent_w
%
world_size_w
==
0
:
return
latent_h
,
latent_w
,
world_size_h
,
world_size_w
# If no perfect fit, find minimal padding solution
best_grid
=
(
1
,
world_size
)
# fallback
min_total_padding
=
float
(
"inf"
)
for
world_size_h
,
world_size_w
in
priority_grids
:
# Calculate required padding
pad_h
=
(
world_size_h
-
(
latent_h
%
world_size_h
))
%
world_size_h
pad_w
=
(
world_size_w
-
(
latent_w
%
world_size_w
))
%
world_size_w
total_padding
=
pad_h
+
pad_w
# Prefer grids with minimal total padding
if
total_padding
<
min_total_padding
:
min_total_padding
=
total_padding
best_grid
=
(
world_size_h
,
world_size_w
)
# Apply padding
world_size_h
,
world_size_w
=
best_grid
pad_h
=
(
world_size_h
-
(
latent_h
%
world_size_h
))
%
world_size_h
pad_w
=
(
world_size_w
-
(
latent_w
%
world_size_w
))
%
world_size_w
return
latent_h
+
pad_h
,
latent_w
+
pad_w
,
world_size_h
,
world_size_w
@
ProfilingContext4DebugL1
(
@
ProfilingContext4DebugL1
(
"Run VAE Encoder"
,
"Run VAE Encoder"
,
recorder_mode
=
GET_RECORDER_MODE
(),
recorder_mode
=
GET_RECORDER_MODE
(),
...
@@ -281,8 +333,19 @@ class WanRunner(DefaultRunner):
...
@@ -281,8 +333,19 @@ class WanRunner(DefaultRunner):
h
,
w
=
first_frame
.
shape
[
2
:]
h
,
w
=
first_frame
.
shape
[
2
:]
aspect_ratio
=
h
/
w
aspect_ratio
=
h
/
w
max_area
=
self
.
config
[
"target_height"
]
*
self
.
config
[
"target_width"
]
max_area
=
self
.
config
[
"target_height"
]
*
self
.
config
[
"target_width"
]
latent_h
=
round
(
np
.
sqrt
(
max_area
*
aspect_ratio
)
//
self
.
config
[
"vae_stride"
][
1
]
//
self
.
config
[
"patch_size"
][
1
]
*
self
.
config
[
"patch_size"
][
1
])
latent_w
=
round
(
np
.
sqrt
(
max_area
/
aspect_ratio
)
//
self
.
config
[
"vae_stride"
][
2
]
//
self
.
config
[
"patch_size"
][
2
]
*
self
.
config
[
"patch_size"
][
2
])
# Calculate initial latent dimensions
ori_latent_h
=
round
(
np
.
sqrt
(
max_area
*
aspect_ratio
)
//
self
.
config
[
"vae_stride"
][
1
]
//
self
.
config
[
"patch_size"
][
1
]
*
self
.
config
[
"patch_size"
][
1
])
ori_latent_w
=
round
(
np
.
sqrt
(
max_area
/
aspect_ratio
)
//
self
.
config
[
"vae_stride"
][
2
]
//
self
.
config
[
"patch_size"
][
2
]
*
self
.
config
[
"patch_size"
][
2
])
# Adjust latent dimensions for optimal 2D grid splitting when using distributed processing
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
latent_h
,
latent_w
,
world_size_h
,
world_size_w
=
self
.
_adjust_latent_for_grid_splitting
(
ori_latent_h
,
ori_latent_w
,
dist
.
get_world_size
())
logger
.
info
(
f
"ori latent:
{
ori_latent_h
}
x
{
ori_latent_w
}
, adjust_latent:
{
latent_h
}
x
{
latent_w
}
, grid:
{
world_size_h
}
x
{
world_size_w
}
"
)
else
:
latent_h
,
latent_w
=
ori_latent_h
,
ori_latent_w
world_size_h
,
world_size_w
=
None
,
None
latent_shape
=
self
.
get_latent_shape_with_lat_hw
(
latent_h
,
latent_w
)
# Important: latent_shape is used to set the input_info
latent_shape
=
self
.
get_latent_shape_with_lat_hw
(
latent_h
,
latent_w
)
# Important: latent_shape is used to set the input_info
if
self
.
config
.
get
(
"changing_resolution"
,
False
):
if
self
.
config
.
get
(
"changing_resolution"
,
False
):
...
@@ -293,8 +356,8 @@ class WanRunner(DefaultRunner):
...
@@ -293,8 +356,8 @@ class WanRunner(DefaultRunner):
int
(
latent_h
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_h
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_w
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
latent_w
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
)
)
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
latent_h_tmp
,
latent_w_tmp
))
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
latent_h_tmp
,
latent_w_tmp
,
world_size_h
=
world_size_h
,
world_size_w
=
world_size_w
))
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
latent_h
,
latent_w
))
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
first_frame
,
latent_h
,
latent_w
,
world_size_h
=
world_size_h
,
world_size_w
=
world_size_w
))
return
vae_encode_out_list
,
latent_shape
return
vae_encode_out_list
,
latent_shape
else
:
else
:
if
last_frame
is
not
None
:
if
last_frame
is
not
None
:
...
@@ -307,10 +370,10 @@ class WanRunner(DefaultRunner):
...
@@ -307,10 +370,10 @@ class WanRunner(DefaultRunner):
round
(
last_frame_size
[
1
]
*
last_frame_resize_ratio
),
round
(
last_frame_size
[
1
]
*
last_frame_resize_ratio
),
]
]
last_frame
=
TF
.
center_crop
(
last_frame
,
last_frame_size
)
last_frame
=
TF
.
center_crop
(
last_frame
,
last_frame_size
)
vae_encoder_out
=
self
.
get_vae_encoder_output
(
first_frame
,
latent_h
,
latent_w
,
last_frame
)
vae_encoder_out
=
self
.
get_vae_encoder_output
(
first_frame
,
latent_h
,
latent_w
,
last_frame
,
world_size_h
=
world_size_h
,
world_size_w
=
world_size_w
)
return
vae_encoder_out
,
latent_shape
return
vae_encoder_out
,
latent_shape
def
get_vae_encoder_output
(
self
,
first_frame
,
lat_h
,
lat_w
,
last_frame
=
None
):
def
get_vae_encoder_output
(
self
,
first_frame
,
lat_h
,
lat_w
,
last_frame
=
None
,
world_size_h
=
None
,
world_size_w
=
None
):
h
=
lat_h
*
self
.
config
[
"vae_stride"
][
1
]
h
=
lat_h
*
self
.
config
[
"vae_stride"
][
1
]
w
=
lat_w
*
self
.
config
[
"vae_stride"
][
2
]
w
=
lat_w
*
self
.
config
[
"vae_stride"
][
2
]
msk
=
torch
.
ones
(
msk
=
torch
.
ones
(
...
@@ -350,7 +413,7 @@ class WanRunner(DefaultRunner):
...
@@ -350,7 +413,7 @@ class WanRunner(DefaultRunner):
dim
=
1
,
dim
=
1
,
).
to
(
AI_DEVICE
)
).
to
(
AI_DEVICE
)
vae_encoder_out
=
self
.
vae_encoder
.
encode
(
vae_input
.
unsqueeze
(
0
).
to
(
GET_DTYPE
()))
vae_encoder_out
=
self
.
vae_encoder
.
encode
(
vae_input
.
unsqueeze
(
0
).
to
(
GET_DTYPE
())
,
world_size_h
=
world_size_h
,
world_size_w
=
world_size_w
)
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
del
self
.
vae_encoder
del
self
.
vae_encoder
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
aa627f77
...
@@ -1119,7 +1119,7 @@ class WanVAE:
...
@@ -1119,7 +1119,7 @@ class WanVAE:
return
encoded
.
squeeze
(
0
)
return
encoded
.
squeeze
(
0
)
def
encode
(
self
,
video
):
def
encode
(
self
,
video
,
world_size_h
=
None
,
world_size_w
=
None
):
"""
"""
video: one video with shape [1, C, T, H, W].
video: one video with shape [1, C, T, H, W].
"""
"""
...
@@ -1132,7 +1132,8 @@ class WanVAE:
...
@@ -1132,7 +1132,8 @@ class WanVAE:
height
,
width
=
video
.
shape
[
3
],
video
.
shape
[
4
]
height
,
width
=
video
.
shape
[
3
],
video
.
shape
[
4
]
if
self
.
use_2d_split
:
if
self
.
use_2d_split
:
world_size_h
,
world_size_w
=
self
.
_calculate_2d_grid
(
height
//
8
,
width
//
8
,
world_size
)
if
world_size_h
is
None
or
world_size_w
is
None
:
world_size_h
,
world_size_w
=
self
.
_calculate_2d_grid
(
height
//
8
,
width
//
8
,
world_size
)
cur_rank_h
=
cur_rank
//
world_size_w
cur_rank_h
=
cur_rank
//
world_size_w
cur_rank_w
=
cur_rank
%
world_size_w
cur_rank_w
=
cur_rank
%
world_size_w
out
=
self
.
encode_dist_2d
(
video
,
world_size_h
,
world_size_w
,
cur_rank_h
,
cur_rank_w
)
out
=
self
.
encode_dist_2d
(
video
,
world_size_h
,
world_size_w
,
cur_rank_h
,
cur_rank_w
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment