Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
5546f759
Unverified
Commit
5546f759
authored
Dec 09, 2025
by
Musisoul
Committed by
GitHub
Dec 09, 2025
Browse files
[feat] stream vae (#582)
parent
0ad8ada3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
142 additions
and
1 deletion
+142
-1
lightx2v/models/runners/default_runner.py
lightx2v/models/runners/default_runner.py
+21
-1
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+121
-0
No files found.
lightx2v/models/runners/default_runner.py
100755 → 100644
View file @
5546f759
...
...
@@ -319,7 +319,14 @@ class DefaultRunner(BaseRunner):
# 2. main inference loop
latents
=
self
.
run_segment
(
segment_idx
)
# 3. vae decoder
self
.
gen_video
=
self
.
run_vae_decoder
(
latents
)
if
self
.
config
.
get
(
"use_stream_vae"
,
False
):
frames
=
[]
for
frame_segment
in
self
.
run_vae_decoder_stream
(
latents
):
frames
.
append
(
frame_segment
)
logger
.
info
(
f
"frame sagment:
{
len
(
frames
)
}
done"
)
self
.
gen_video
=
torch
.
cat
(
frames
,
dim
=
2
)
else
:
self
.
gen_video
=
self
.
run_vae_decoder
(
latents
)
# 4. default do nothing
self
.
end_run_segment
(
segment_idx
)
gen_video_final
=
self
.
process_images_after_vae_decoder
()
...
...
@@ -337,6 +344,19 @@ class DefaultRunner(BaseRunner):
gc
.
collect
()
return
images
@ProfilingContext4DebugL1("Run VAE Decoder Stream", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
def run_vae_decoder_stream(self, latents):
    """Stream-decode latents into video frame segments.

    Generator counterpart of ``run_vae_decoder``: yields decoded frame
    segments one at a time instead of materializing the whole video.
    When ``lazy_load``/``unload_modules`` is set, the VAE decoder is
    (re)loaded on entry and torn down once the stream is exhausted.

    Args:
        latents: latent tensor to decode; cast to the runtime dtype.

    Yields:
        Decoded frame segments, in temporal order.
    """
    # (Re)load the decoder when the config asks for on-demand module lifetime.
    if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
        self.vae_decoder = self.load_vae_decoder()

    # Delegate segment-by-segment decoding to the VAE's streaming API.
    yield from self.vae_decoder.decode_stream(latents.to(GET_DTYPE()))

    # NOTE(review): this teardown only runs when the generator is fully
    # exhausted; an abandoned iterator keeps the decoder loaded.
    if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
        del self.vae_decoder
        torch.cuda.empty_cache()
        gc.collect()
def
post_prompt_enhancer
(
self
):
while
True
:
for
url
in
self
.
config
[
"sub_servers"
][
"prompt_enhancer"
]:
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
5546f759
...
...
@@ -724,6 +724,25 @@ class WanVAE_(nn.Module):
self
.
clear_cache
()
return
out
def decode_stream(self, z, scale):
    """Denormalize latents and decode them one temporal slice at a time.

    Args:
        z: latent tensor laid out as [b, c, t, h, w].
        scale: pair ``(shift, inv_scale)`` undoing latent normalization via
            ``z / scale[1] + scale[0]``; entries may be scalars or
            per-channel tensors (broadcast over the channel axis).

    Yields:
        The decoder output for each temporal index, in order.
    """
    self.clear_cache()
    shift, inv_scale = scale[0], scale[1]
    if isinstance(shift, torch.Tensor):
        # Per-channel statistics: reshape for broadcasting over [b,c,t,h,w].
        bshape = (1, self.z_dim, 1, 1, 1)
        z = z / inv_scale.view(bshape) + shift.view(bshape)
    else:
        z = z / inv_scale + shift
    num_frames = z.shape[2]
    x = self.conv2(z)
    # NOTE(review): cache is cleared on entry but not after the loop —
    # presumably the next call's clear_cache() covers it; confirm against
    # cached_decode's convention.
    for t in range(num_frames):
        # Reset the per-frame conv-cache cursor before each decoder call.
        self._conv_idx = [0]
        yield self.decoder(
            x[:, :, t : t + 1, :, :],
            feat_cache=self._feat_map,
            feat_idx=self._conv_idx,
        )
def
cached_decode
(
self
,
z
,
scale
):
# z: [b,c,t,h,w]
if
isinstance
(
scale
[
0
],
torch
.
Tensor
):
...
...
@@ -1291,6 +1310,87 @@ class WanVAE:
return
images
def decode_dist_2d_stream(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
    """Streamed 2D-parallel VAE decode over a (world_size_h x world_size_w) rank grid.

    Each rank decodes its own spatially-overlapping latent chunk; decoded
    chunks are then all-gathered and stitched back into full frames, which
    are yielded one streaming segment at a time.

    Args:
        zs: latent tensor; dims 2 and 3 are the spatial H and W axes that
            get partitioned across ranks.
        world_size_h, world_size_w: extents of the process grid.
        cur_rank_h, cur_rank_w: this process's coordinates in the grid.

    Yields:
        Full stitched image tensors, one per segment produced by the
        underlying streaming decoder.
    """
    total_h = zs.shape[2]
    total_w = zs.shape[3]
    chunk_h = total_h // world_size_h
    chunk_w = total_w // world_size_w
    # One latent pixel of halo on each interior chunk edge so the decoder
    # sees context across the cut; trimmed off again after decoding.
    padding_size = 1
    # Calculate H dimension slice
    if cur_rank_h == 0:
        h_start = 0
        h_end = chunk_h + 2 * padding_size
    elif cur_rank_h == world_size_h - 1:
        h_start = total_h - (chunk_h + 2 * padding_size)
        h_end = total_h
    else:
        h_start = cur_rank_h * chunk_h - padding_size
        h_end = (cur_rank_h + 1) * chunk_h + padding_size
    # Calculate W dimension slice
    if cur_rank_w == 0:
        w_start = 0
        w_end = chunk_w + 2 * padding_size
    elif cur_rank_w == world_size_w - 1:
        w_start = total_w - (chunk_w + 2 * padding_size)
        w_end = total_w
    else:
        w_start = cur_rank_w * chunk_w - padding_size
        w_end = (cur_rank_w + 1) * chunk_w + padding_size
    # Extract the latent chunk for this process
    zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()
    for image in self.model.decode_stream(zs_chunk.unsqueeze(0), self.scale):
        images_chunk = image.clamp_(-1, 1)
        # Remove padding from decoded chunk
        # Each latent pixel maps to 8 decoded pixels, so the halo is
        # padding_size * 8 pixels wide in decoded space (dims 3/4 = H/W).
        spatial_ratio = 8
        if cur_rank_h == 0:
            decoded_h_start = 0
            decoded_h_end = chunk_h * spatial_ratio
        elif cur_rank_h == world_size_h - 1:
            decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
            decoded_h_end = images_chunk.shape[3]
        else:
            decoded_h_start = padding_size * spatial_ratio
            decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio
        if cur_rank_w == 0:
            decoded_w_start = 0
            decoded_w_end = chunk_w * spatial_ratio
        elif cur_rank_w == world_size_w - 1:
            decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
            decoded_w_end = images_chunk.shape[4]
        else:
            decoded_w_start = padding_size * spatial_ratio
            decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio
        images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()
        # Gather all chunks
        # all_gather orders results by global rank; the stitch below assumes
        # rank == h_idx * world_size_w + w_idx, matching the rank mapping
        # used by decode_stream (cur_rank // world_size_w, cur_rank % world_size_w).
        total_processes = world_size_h * world_size_w
        full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
        dist.all_gather(full_images, images_chunk)
        self.device_synchronize()
        # Reconstruct the full image tensor
        image_rows = []
        for h_idx in range(world_size_h):
            image_cols = []
            for w_idx in range(world_size_w):
                process_idx = h_idx * world_size_w + w_idx
                image_cols.append(full_images[process_idx])
            # Concatenate a row of chunks along W (dim 4), then rows along H (dim 3).
            image_rows.append(torch.cat(image_cols, dim=4))
        images = torch.cat(image_rows, dim=3)
        yield images
def
decode
(
self
,
zs
):
if
self
.
cpu_offload
:
self
.
to_cuda
()
...
...
@@ -1324,6 +1424,27 @@ class WanVAE:
return
images
def decode_stream(self, zs):
    """Streaming counterpart of ``decode``: yields decoded image segments.

    Moves the model to GPU first when ``cpu_offload`` is enabled and back
    to CPU once the stream is exhausted. With ``parallel`` enabled, work is
    sharded over a 2D rank grid via ``decode_dist_2d_stream``; otherwise
    the wrapped model's streaming decoder is used directly and each output
    is clamped in-place to [-1, 1].
    """
    if self.cpu_offload:
        self.to_cuda()
    if self.parallel:
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        # Pick a 2D process grid from the latent spatial extents.
        grid_h, grid_w = self._calculate_2d_grid(zs.shape[2], zs.shape[3], world_size)
        yield from self.decode_dist_2d_stream(zs, grid_h, grid_w, rank // grid_w, rank % grid_w)
    else:
        for segment in self.model.decode_stream(zs.unsqueeze(0), self.scale):
            yield segment.clamp_(-1, 1)
    # NOTE: runs only when the generator is driven to exhaustion.
    if self.cpu_offload:
        self.to_cpu()
def encode_video(self, vid):
    """Encode a video into latents by delegating to the wrapped VAE model."""
    encoder = self.model.encode_video
    return encoder(vid)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment