LightX2V — commit c05ebad7 (unverified)
Authored Dec 03, 2025 by Musisoul; committed via GitHub on Dec 03, 2025
Support hunyuan parallel vae (#560)
Parent: 58f84489
Showing 2 changed files with 169 additions and 4 deletions.
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py (+69 −2)
lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py (+100 −2)
lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py
...
...
@@ -128,10 +128,75 @@ class HunyuanVideo15Runner(DefaultRunner):
            target_height // self.config["vae_stride"][1],
            target_width // self.config["vae_stride"][2],
        ]
        self.target_height = target_height
        self.target_width = target_width

        ori_latent_h, ori_latent_w = latent_shape[2], latent_shape[3]
        if dist.is_initialized() and dist.get_world_size() > 1:
            latent_h, latent_w, world_size_h, world_size_w = self._adjust_latent_for_grid_splitting(ori_latent_h, ori_latent_w, dist.get_world_size())
            latent_shape[2], latent_shape[3] = latent_h, latent_w
            logger.info(f"ori latent: {ori_latent_h}x{ori_latent_w}, adjust_latent: {latent_h}x{latent_w}, grid: {world_size_h}x{world_size_w}")
        else:
            latent_shape[2], latent_shape[3] = ori_latent_h, ori_latent_w
            world_size_h, world_size_w = None, None

        self.vae_decoder.world_size_h = world_size_h
        self.vae_decoder.world_size_w = world_size_w

        self.target_height = latent_shape[2] * self.config["vae_stride"][1]
        self.target_width = latent_shape[3] * self.config["vae_stride"][2]
        return latent_shape

    def _adjust_latent_for_grid_splitting(self, latent_h, latent_w, world_size):
        """
        Adjust latent dimensions for optimal 2D grid splitting.
        Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1.
        """
        world_size_h, world_size_w = 1, 1
        if world_size <= 1:
            return latent_h, latent_w, world_size_h, world_size_w

        # Define priority grids for different world sizes
        priority_grids = []
        if world_size == 8:
            # For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1
            priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)]
        elif world_size == 4:
            priority_grids = [(2, 2), (1, 4), (4, 1)]
        elif world_size == 2:
            priority_grids = [(1, 2), (2, 1)]
        else:
            # For other sizes, try factor pairs
            for h in range(1, int(np.sqrt(world_size)) + 1):
                if world_size % h == 0:
                    w = world_size // h
                    priority_grids.append((h, w))

        # Try priority grids first
        for world_size_h, world_size_w in priority_grids:
            if latent_h % world_size_h == 0 and latent_w % world_size_w == 0:
                return latent_h, latent_w, world_size_h, world_size_w

        # If no perfect fit, find minimal padding solution
        best_grid = (1, world_size)  # fallback
        min_total_padding = float("inf")
        for world_size_h, world_size_w in priority_grids:
            # Calculate required padding
            pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
            pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
            total_padding = pad_h + pad_w
            # Prefer grids with minimal total padding
            if total_padding < min_total_padding:
                min_total_padding = total_padding
                best_grid = (world_size_h, world_size_w)

        # Apply padding
        world_size_h, world_size_w = best_grid
        pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
        pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
        return latent_h + pad_h, latent_w + pad_w, world_size_h, world_size_w

    def get_sr_latent_shape_with_target_hw(self):
        SizeMap = {
            "480p": 640,
...
...
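The grid-selection heuristic above is easiest to see on concrete numbers. Below is a small standalone sketch (not part of the commit) that mirrors the 8-rank branch of _adjust_latent_for_grid_splitting; the helper name pick_grid and the sample latent shapes are made up for illustration.

def pick_grid(latent_h, latent_w, world_size=8):
    # Same priority order the method uses for 8 ranks:
    # balanced grids first, degenerate 1x8 / 8x1 strips last.
    priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)]
    # A grid that divides both dimensions evenly wins outright.
    for gh, gw in priority_grids:
        if latent_h % gh == 0 and latent_w % gw == 0:
            return latent_h, latent_w, gh, gw
    # Otherwise pick the grid that needs the least total padding.
    best, best_pad = (1, world_size), float("inf")
    for gh, gw in priority_grids:
        pad = (gh - latent_h % gh) % gh + (gw - latent_w % gw) % gw
        if pad < best_pad:
            best, best_pad = (gh, gw), pad
    gh, gw = best
    return (latent_h + (gh - latent_h % gh) % gh,
            latent_w + (gw - latent_w % gw) % gw, gh, gw)

print(pick_grid(68, 120))  # (68, 120, 2, 4): balanced 2x4 grid, no padding needed
print(pick_grid(45, 80))   # (45, 80, 1, 8): only the 1x8 strip divides both dims evenly
print(pick_grid(45, 77))   # (45, 80, 1, 8): no exact fit; 1x8 needs the least padding (W: 77 -> 80)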
@@ -254,6 +319,7 @@ class HunyuanVideo15Runner(DefaultRunner):
"device"
:
vae_device
,
"cpu_offload"
:
vae_offload
,
"dtype"
:
GET_DTYPE
(),
"parallel"
:
self
.
config
[
"parallel"
],
}
if
self
.
config
[
"task"
]
not
in
[
"i2v"
,
"flf2v"
,
"animate"
,
"vace"
,
"s2v"
]:
return
None
...
...
@@ -273,6 +339,7 @@ class HunyuanVideo15Runner(DefaultRunner):
"device"
:
vae_device
,
"cpu_offload"
:
vae_offload
,
"dtype"
:
GET_DTYPE
(),
"parallel"
:
self
.
config
[
"parallel"
],
}
if
self
.
config
.
get
(
"use_tae"
,
False
):
tae_path
=
self
.
config
[
"tae_path"
]
...
...
lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py
...
...
@@ -5,6 +5,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.autoencoders.vae import BaseOutput, DiagonalGaussianDistribution
...
...
@@ -787,9 +788,11 @@ class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
class HunyuanVideo15VAE:
-   def __init__(self, checkpoint_path=None, dtype=torch.float16, device="cuda", cpu_offload=False):
+   def __init__(self, checkpoint_path=None, dtype=torch.float16, device="cuda", cpu_offload=False, parallel=False):
        self.vae = AutoencoderKLConv3D.from_pretrained(os.path.join(checkpoint_path, "vae")).to(dtype).to(device)
        self.vae.cpu_offload = cpu_offload
        self.parallel = parallel
        self.world_size_h, self.world_size_w = None, None

    @torch.no_grad()
    def encode(self, x):
...
...
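As a rough usage sketch (not from the repository) of how the new parallel flag and the grid attributes are meant to be driven: it assumes a torchrun-style job with torch.distributed available, hard-codes the 2x4 grid the runner would normally compute, and calls decode_dist_2d directly; the checkpoint path and latent shape are placeholders, and the real decode path additionally rescales the latent and toggles tiling, as the next hunk shows.

import torch
import torch.distributed as dist

from lightx2v.models.video_encoders.hf.hunyuanvideo15.hunyuanvideo_15_vae import HunyuanVideo15VAE

# Hypothetical launch: torchrun --nproc_per_node=8 this_sketch.py
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

vae = HunyuanVideo15VAE(
    checkpoint_path="/path/to/hunyuanvideo-1.5",  # placeholder path
    dtype=torch.float16,
    device="cuda",
    parallel=True,
)

# The runner normally derives these via _adjust_latent_for_grid_splitting;
# here a 2x4 grid for 8 ranks is hard-coded.
vae.world_size_h, vae.world_size_w = 2, 4

# Every rank holds the same full latent; decode_dist_2d slices out its own
# overlapping tile, decodes it, crops the halo, and all_gathers the pieces.
z = torch.randn(1, 16, 9, 68, 120, dtype=torch.float16, device="cuda")  # illustrative shape only
frames = vae.decode_dist_2d(z, vae.world_size_h, vae.world_size_w)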
@@ -800,10 +803,105 @@ class HunyuanVideo15VAE:
        z = z / self.vae.config.scaling_factor
        self.vae.enable_tiling()
        if self.parallel and self.world_size_h is not None and self.world_size_w is not None:
            video_frames = self.decode_dist_2d(z, self.world_size_h, self.world_size_w)
            self.world_size_h, self.world_size_w = None, None
        else:
            video_frames = self.vae.decode(z, return_dict=False)[0]
        self.vae.disable_tiling()
        return video_frames

    @torch.no_grad()
    def decode_dist_2d(self, z, world_size_h, world_size_w):
        cur_rank = dist.get_rank()
        cur_rank_h = cur_rank // world_size_w
        cur_rank_w = cur_rank % world_size_w

        total_h = z.shape[3]
        total_w = z.shape[4]
        chunk_h = total_h // world_size_h
        chunk_w = total_w // world_size_w
        padding_size = 1

        # Calculate H dimension slice
        if cur_rank_h == 0:
            h_start = 0
            h_end = chunk_h + 2 * padding_size
        elif cur_rank_h == world_size_h - 1:
            h_start = total_h - (chunk_h + 2 * padding_size)
            h_end = total_h
        else:
            h_start = cur_rank_h * chunk_h - padding_size
            h_end = (cur_rank_h + 1) * chunk_h + padding_size

        # Calculate W dimension slice
        if cur_rank_w == 0:
            w_start = 0
            w_end = chunk_w + 2 * padding_size
        elif cur_rank_w == world_size_w - 1:
            w_start = total_w - (chunk_w + 2 * padding_size)
            w_end = total_w
        else:
            w_start = cur_rank_w * chunk_w - padding_size
            w_end = (cur_rank_w + 1) * chunk_w + padding_size

        # Extract the latent chunk for this process
        zs_chunk = z[:, :, :, h_start:h_end, w_start:w_end].contiguous()

        # Decode the chunk
        images_chunk = self.vae.decode(zs_chunk, return_dict=False)[0]

        # Remove padding from decoded chunk
        spatial_ratio = 16
        if cur_rank_h == 0:
            decoded_h_start = 0
            decoded_h_end = chunk_h * spatial_ratio
        elif cur_rank_h == world_size_h - 1:
            decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
            decoded_h_end = images_chunk.shape[3]
        else:
            decoded_h_start = padding_size * spatial_ratio
            decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

        if cur_rank_w == 0:
            decoded_w_start = 0
            decoded_w_end = chunk_w * spatial_ratio
        elif cur_rank_w == world_size_w - 1:
            decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
            decoded_w_end = images_chunk.shape[4]
        else:
            decoded_w_start = padding_size * spatial_ratio
            decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio

        images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()

        # Gather all chunks
        total_processes = world_size_h * world_size_w
        full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
        dist.all_gather(full_images, images_chunk)
        self.device_synchronize()

        # Reconstruct the full image tensor
        image_rows = []
        for h_idx in range(world_size_h):
            image_cols = []
            for w_idx in range(world_size_w):
                process_idx = h_idx * world_size_w + w_idx
                image_cols.append(full_images[process_idx])
            image_rows.append(torch.cat(image_cols, dim=4))
        images = torch.cat(image_rows, dim=3)
        return images

    def device_synchronize(self,):
        torch_device_module.synchronize()


if __name__ == "__main__":
    vae = HunyuanVideo15VAE(checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5", dtype=torch.float16, device="cuda")
...
...
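The halo-and-crop arithmetic in decode_dist_2d can be checked in isolation. Below is an illustration-only sketch (not from the commit) that mirrors the H-dimension bounds for each rank row and verifies that the cropped pixel tiles cover the full height exactly once, under the same padding_size of 1 latent row and the 16x spatial ratio the method assumes.

def h_bounds(rank_h, world_size_h, total_h, padding_size=1, spatial_ratio=16):
    """Mirror the H-dimension logic of decode_dist_2d for one rank row."""
    chunk_h = total_h // world_size_h
    # Latent slice, with a one-row halo so tile borders decode cleanly.
    if rank_h == 0:
        h_start, h_end = 0, chunk_h + 2 * padding_size
    elif rank_h == world_size_h - 1:
        h_start, h_end = total_h - (chunk_h + 2 * padding_size), total_h
    else:
        h_start, h_end = rank_h * chunk_h - padding_size, (rank_h + 1) * chunk_h + padding_size
    # Crop of the decoded tile (pixel space) that removes the halo again.
    decoded_h = (h_end - h_start) * spatial_ratio
    if rank_h == 0:
        d_start, d_end = 0, chunk_h * spatial_ratio
    elif rank_h == world_size_h - 1:
        d_start, d_end = decoded_h - chunk_h * spatial_ratio, decoded_h
    else:
        d_start, d_end = padding_size * spatial_ratio, decoded_h - padding_size * spatial_ratio
    return (h_start, h_end), (d_start, d_end)

total_h, world_size_h = 68, 2  # e.g. a 68-row latent split across 2 rank rows
kept = 0
for r in range(world_size_h):
    (hs, he), (ds, de) = h_bounds(r, world_size_h, total_h)
    print(f"rank_h={r}: latent rows [{hs}, {he}) -> keep decoded rows [{ds}, {de})")
    kept += de - ds
assert kept == total_h * 16  # the cropped tiles cover the full 1088-pixel height exactly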