xuwx1 / LightX2V / Commits / c05ebad7

Unverified commit c05ebad7, authored Dec 03, 2025 by Musisoul, committed by GitHub on Dec 03, 2025.
Parent: 58f84489

Support hunyuan parallel vae (#560)

Showing 2 changed files with 169 additions and 4 deletions:

  lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py          +69  -2
  lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py   +100 -2
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py
@@ -128,10 +128,75 @@ class HunyuanVideo15Runner(DefaultRunner):
            target_height // self.config["vae_stride"][1],
            target_width // self.config["vae_stride"][2],
        ]
        self.target_height = target_height
        self.target_width = target_width

        ori_latent_h, ori_latent_w = latent_shape[2], latent_shape[3]
        if dist.is_initialized() and dist.get_world_size() > 1:
            latent_h, latent_w, world_size_h, world_size_w = self._adjust_latent_for_grid_splitting(ori_latent_h, ori_latent_w, dist.get_world_size())
            latent_shape[2], latent_shape[3] = latent_h, latent_w
            logger.info(f"ori latent: {ori_latent_h}x{ori_latent_w}, adjust_latent: {latent_h}x{latent_w}, grid: {world_size_h}x{world_size_w}")
        else:
            latent_shape[2], latent_shape[3] = ori_latent_h, ori_latent_w
            world_size_h, world_size_w = None, None
        self.vae_decoder.world_size_h = world_size_h
        self.vae_decoder.world_size_w = world_size_w
        self.target_height = latent_shape[2] * self.config["vae_stride"][1]
        self.target_width = latent_shape[3] * self.config["vae_stride"][2]
        return latent_shape
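    # Hedged note: target_height / target_width are recomputed above from the possibly
    # padded latent_shape so that the decoded video matches what the grid split actually
    # produces. Assuming a spatial vae_stride of 16 (not confirmed by this hunk), padding
    # the latent width from 79 to 80 would move target_width from 79 * 16 = 1264 to
    # 80 * 16 = 1280.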
    def _adjust_latent_for_grid_splitting(self, latent_h, latent_w, world_size):
        """
        Adjust latent dimensions for optimal 2D grid splitting.
        Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1.
        """
        world_size_h, world_size_w = 1, 1
        if world_size <= 1:
            return latent_h, latent_w, world_size_h, world_size_w

        # Define priority grids for different world sizes
        priority_grids = []
        if world_size == 8:
            # For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1
            priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)]
        elif world_size == 4:
            priority_grids = [(2, 2), (1, 4), (4, 1)]
        elif world_size == 2:
            priority_grids = [(1, 2), (2, 1)]
        else:
            # For other sizes, try factor pairs
            for h in range(1, int(np.sqrt(world_size)) + 1):
                if world_size % h == 0:
                    w = world_size // h
                    priority_grids.append((h, w))

        # Try priority grids first
        for world_size_h, world_size_w in priority_grids:
            if latent_h % world_size_h == 0 and latent_w % world_size_w == 0:
                return latent_h, latent_w, world_size_h, world_size_w

        # If no perfect fit, find minimal padding solution
        best_grid = (1, world_size)  # fallback
        min_total_padding = float("inf")
        for world_size_h, world_size_w in priority_grids:
            # Calculate required padding
            pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
            pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
            total_padding = pad_h + pad_w
            # Prefer grids with minimal total padding
            if total_padding < min_total_padding:
                min_total_padding = total_padding
                best_grid = (world_size_h, world_size_w)

        # Apply padding
        world_size_h, world_size_w = best_grid
        pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
        pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
        return latent_h + pad_h, latent_w + pad_w, world_size_h, world_size_w
    def get_sr_latent_shape_with_target_hw(self):
        SizeMap = {
            "480p": 640,
            ...
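For illustration, the grid selection in _adjust_latent_for_grid_splitting first looks for a priority grid that divides both latent dimensions exactly and only falls back to padding when nothing fits. A minimal standalone sketch of that idea (hypothetical helper, not part of this commit):

    def pick_grid(latent_h, latent_w, priority_grids):
        # First pass: a grid that divides both dimensions exactly needs no padding.
        for gh, gw in priority_grids:
            if latent_h % gh == 0 and latent_w % gw == 0:
                return gh, gw, 0, 0
        # Fallback: the grid needing the least total padding (first one wins on ties).
        gh, gw = min(
            priority_grids,
            key=lambda g: (g[0] - latent_h % g[0]) % g[0] + (g[1] - latent_w % g[1]) % g[1],
        )
        return gh, gw, (gh - latent_h % gh) % gh, (gw - latent_w % gw) % gw

    grids_8 = [(2, 4), (4, 2), (1, 8), (8, 1)]
    print(pick_grid(44, 80, grids_8))  # (2, 4, 0, 0): balanced 2x4 grid, no padding needed
    print(pick_grid(45, 80, grids_8))  # (1, 8, 0, 0): 45 is odd, so only the 1x8 grid fits exactly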
@@ -254,6 +319,7 @@ class HunyuanVideo15Runner(DefaultRunner):
"device"
:
vae_device
,
"cpu_offload"
:
vae_offload
,
"dtype"
:
GET_DTYPE
(),
"parallel"
:
self
.
config
[
"parallel"
],
}
if
self
.
config
[
"task"
]
not
in
[
"i2v"
,
"flf2v"
,
"animate"
,
"vace"
,
"s2v"
]:
return
None
...
...
@@ -273,6 +339,7 @@ class HunyuanVideo15Runner(DefaultRunner):
"device"
:
vae_device
,
"cpu_offload"
:
vae_offload
,
"dtype"
:
GET_DTYPE
(),
"parallel"
:
self
.
config
[
"parallel"
],
}
if
self
.
config
.
get
(
"use_tae"
,
False
):
tae_path
=
self
.
config
[
"tae_path"
]
...
...
lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py
@@ -5,6 +5,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.autoencoders.vae import BaseOutput, DiagonalGaussianDistribution
...
@@ -787,9 +788,11 @@ class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
class HunyuanVideo15VAE:
    # The constructor gains a `parallel` flag in this commit; previously it took only
    # checkpoint_path, dtype, device, and cpu_offload.
    def __init__(self, checkpoint_path=None, dtype=torch.float16, device="cuda", cpu_offload=False, parallel=False):
        self.vae = AutoencoderKLConv3D.from_pretrained(os.path.join(checkpoint_path, "vae")).to(dtype).to(device)
        self.vae.cpu_offload = cpu_offload
        self.parallel = parallel
        self.world_size_h, self.world_size_w = None, None

    @torch.no_grad()
    def encode(self, x):
        ...
@@ -800,10 +803,105 @@ class HunyuanVideo15VAE:
        z = z / self.vae.config.scaling_factor
        self.vae.enable_tiling()
        if self.parallel and self.world_size_h is not None and self.world_size_w is not None:
            video_frames = self.decode_dist_2d(z, self.world_size_h, self.world_size_w)
            self.world_size_h, self.world_size_w = None, None
        else:
            video_frames = self.vae.decode(z, return_dict=False)[0]
        self.vae.disable_tiling()
        return video_frames
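    # Usage sketch (hedged; names such as ckpt_path are illustrative, not from this diff):
    # every rank holds the same full latent z. The runner stores the 2D grid shape on this
    # wrapper before decoding, and the parallel branch above then has each rank decode only
    # its own spatial chunk via decode_dist_2d() and all_gather the pieces back together.
    # For example, under an 8-rank launch the runner may choose a 2x4 grid:
    #
    #     vae = HunyuanVideo15VAE(ckpt_path, parallel=True)
    #     vae.world_size_h, vae.world_size_w = 2, 4
    #     frames = vae.decode(z)   # same z on all 8 ranks; returns the stitched frames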
    @torch.no_grad()
    def decode_dist_2d(self, z, world_size_h, world_size_w):
        cur_rank = dist.get_rank()
        cur_rank_h = cur_rank // world_size_w
        cur_rank_w = cur_rank % world_size_w

        total_h = z.shape[3]
        total_w = z.shape[4]
        chunk_h = total_h // world_size_h
        chunk_w = total_w // world_size_w
        padding_size = 1

        # Calculate H dimension slice
        if cur_rank_h == 0:
            h_start = 0
            h_end = chunk_h + 2 * padding_size
        elif cur_rank_h == world_size_h - 1:
            h_start = total_h - (chunk_h + 2 * padding_size)
            h_end = total_h
        else:
            h_start = cur_rank_h * chunk_h - padding_size
            h_end = (cur_rank_h + 1) * chunk_h + padding_size

        # Calculate W dimension slice
        if cur_rank_w == 0:
            w_start = 0
            w_end = chunk_w + 2 * padding_size
        elif cur_rank_w == world_size_w - 1:
            w_start = total_w - (chunk_w + 2 * padding_size)
            w_end = total_w
        else:
            w_start = cur_rank_w * chunk_w - padding_size
            w_end = (cur_rank_w + 1) * chunk_w + padding_size

        # Extract the latent chunk for this process
        zs_chunk = z[:, :, :, h_start:h_end, w_start:w_end].contiguous()

        # Decode the chunk
        images_chunk = self.vae.decode(zs_chunk, return_dict=False)[0]

        # Remove padding from decoded chunk
        spatial_ratio = 16
        if cur_rank_h == 0:
            decoded_h_start = 0
            decoded_h_end = chunk_h * spatial_ratio
        elif cur_rank_h == world_size_h - 1:
            decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
            decoded_h_end = images_chunk.shape[3]
        else:
            decoded_h_start = padding_size * spatial_ratio
            decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

        if cur_rank_w == 0:
            decoded_w_start = 0
            decoded_w_end = chunk_w * spatial_ratio
        elif cur_rank_w == world_size_w - 1:
            decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
            decoded_w_end = images_chunk.shape[4]
        else:
            decoded_w_start = padding_size * spatial_ratio
            decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio

        images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()

        # Gather all chunks
        total_processes = world_size_h * world_size_w
        full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
        dist.all_gather(full_images, images_chunk)
        self.device_synchronize()

        # Reconstruct the full image tensor
        image_rows = []
        for h_idx in range(world_size_h):
            image_cols = []
            for w_idx in range(world_size_w):
                process_idx = h_idx * world_size_w + w_idx
                image_cols.append(full_images[process_idx])
            image_rows.append(torch.cat(image_cols, dim=4))
        images = torch.cat(image_rows, dim=3)
        return images
    def device_synchronize(self):
        # torch_device_module is defined elsewhere in this file (not shown in this hunk);
        # it presumably resolves to the active device backend, e.g. torch.cuda.
        torch_device_module.synchronize()
if __name__ == "__main__":
    vae = HunyuanVideo15VAE(checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5", dtype=torch.float16, device="cuda")
    ...
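For reference, the slice-and-crop arithmetic in decode_dist_2d can be checked in isolation. The sketch below (plain Python, no torch.distributed needed; it assumes the decoder upsamples each latent position by exactly spatial_ratio, as the cropping above does) computes, for one axis, the padded latent slice each rank reads and the pixel crop it keeps after decoding, then confirms the kept pieces tile the full output:

    def chunk_bounds(rank, world, total, pad=1, ratio=16):
        # Returns (latent slice, pixel crop within the decoded chunk) for one grid axis.
        chunk = total // world
        if rank == 0:
            start, end = 0, chunk + 2 * pad
            crop = (0, chunk * ratio)
        elif rank == world - 1:
            start, end = total - (chunk + 2 * pad), total
            crop = ((end - start) * ratio - chunk * ratio, (end - start) * ratio)
        else:
            start, end = rank * chunk - pad, (rank + 1) * chunk + pad
            crop = (pad * ratio, (end - start) * ratio - pad * ratio)
        return (start, end), crop

    # A latent height of 48 split over 4 rows: each rank decodes a 14-high latent slice
    # (its 12-row chunk plus halo), keeps 12 * 16 = 192 pixels after cropping, and the
    # four crops together tile the full 48 * 16 = 768 pixels.
    total, world = 48, 4
    kept = [chunk_bounds(r, world, total)[1] for r in range(world)]
    assert sum(c[1] - c[0] for c in kept) == total * 16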