Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
5b902afb
Unverified
Commit
5b902afb
authored
Dec 10, 2025
by
LiangLiu
Committed by
GitHub
Dec 10, 2025
Browse files
Stream vae (#594)
parent
97549ed0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
11 deletions
+41
-11
lightx2v/deploy/common/va_controller.py
lightx2v/deploy/common/va_controller.py
+4
-0
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+37
-11
No files found.
lightx2v/deploy/common/va_controller.py
View file @
5b902afb
...
@@ -138,11 +138,15 @@ class VAController:
...
@@ -138,11 +138,15 @@ class VAController:
dist
.
barrier
()
dist
.
barrier
()
def
next_control
(
self
):
def
next_control
(
self
):
from
lightx2v.deploy.common.va_reader_omni
import
OmniVAReader
if
isinstance
(
self
.
reader
,
OmniVAReader
):
if
isinstance
(
self
.
reader
,
OmniVAReader
):
return
self
.
omni_reader_next_control
()
return
self
.
omni_reader_next_control
()
return
NextControl
(
action
=
"fetch"
)
return
NextControl
(
action
=
"fetch"
)
def
before_control
(
self
):
def
before_control
(
self
):
from
lightx2v.deploy.common.va_reader_omni
import
OmniVAReader
if
isinstance
(
self
.
reader
,
OmniVAReader
):
if
isinstance
(
self
.
reader
,
OmniVAReader
):
self
.
len_tensor
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
AI_DEVICE
)
self
.
len_tensor
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
AI_DEVICE
)
self
.
flag_tensor
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
AI_DEVICE
)
self
.
flag_tensor
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
AI_DEVICE
)
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
5b902afb
...
@@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Tuple, Union
...
@@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Tuple, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torchaudio
as
ta
import
torchaudio
as
ta
import
torchvision.transforms.functional
as
TF
import
torchvision.transforms.functional
as
TF
...
@@ -711,13 +710,37 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -711,13 +710,37 @@ class WanAudioRunner(WanRunner): # type:ignore
del
video_seg
,
audio_seg
del
video_seg
,
audio_seg
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
get_rank_and_world_size
(
self
):
@
ProfilingContext4DebugL1
(
rank
=
0
"End run segment stream"
,
world_size
=
1
recorder_mode
=
GET_RECORDER_MODE
(),
if
dist
.
is_initialized
():
metrics_func
=
monitor_cli
.
lightx2v_run_end_run_segment_duration
,
rank
=
dist
.
get_rank
()
metrics_labels
=
[
"WanAudioRunner"
],
world_size
=
dist
.
get_world_size
()
)
return
rank
,
world_size
def
end_run_segment_stream
(
self
,
latents
):
valid_length
=
self
.
segment
.
end_frame
-
self
.
segment
.
start_frame
frame_segments
=
[]
frame_idx
=
0
# frame_segment: 1*C*1*H*W, 1*C*4*H*W, 1*C*4*H*W, ...
for
origin_seg
in
self
.
run_vae_decoder_stream
(
latents
):
origin_seg
=
torch
.
clamp
(
origin_seg
,
-
1
,
1
).
to
(
torch
.
float
)
valid_T
=
min
(
valid_length
-
frame_idx
,
origin_seg
.
shape
[
2
])
video_seg
=
vae_to_comfyui_image_inplace
(
origin_seg
[:,
:,
:
valid_T
].
cpu
())
audio_start
=
frame_idx
*
self
.
_audio_processor
.
audio_frame_rate
audio_end
=
(
frame_idx
+
valid_T
)
*
self
.
_audio_processor
.
audio_frame_rate
audio_seg
=
self
.
segment
.
audio_array
[:,
audio_start
:
audio_end
].
sum
(
dim
=
0
)
if
self
.
va_controller
.
recorder
is
not
None
:
self
.
va_controller
.
pub_livestream
(
video_seg
,
audio_seg
,
origin_seg
[:,
:,
:
valid_T
])
frame_segments
.
append
(
origin_seg
)
frame_idx
+=
valid_T
del
video_seg
,
audio_seg
# Update prev_video for next iteration
self
.
prev_video
=
torch
.
cat
(
frame_segments
,
dim
=
2
)
torch
.
cuda
.
empty_cache
()
def
run_main
(
self
):
def
run_main
(
self
):
try
:
try
:
...
@@ -764,9 +787,12 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -764,9 +787,12 @@ class WanAudioRunner(WanRunner): # type:ignore
self
.
check_stop
()
self
.
check_stop
()
latents
=
self
.
run_segment
(
segment_idx
)
latents
=
self
.
run_segment
(
segment_idx
)
self
.
check_stop
()
self
.
check_stop
()
self
.
gen_video
=
self
.
run_vae_decoder
(
latents
)
if
self
.
config
.
get
(
"use_stream_vae"
,
False
):
self
.
check_stop
()
self
.
end_run_segment_stream
(
latents
)
self
.
end_run_segment
(
segment_idx
)
else
:
self
.
gen_video
=
self
.
run_vae_decoder
(
latents
)
self
.
check_stop
()
self
.
end_run_segment
(
segment_idx
)
segment_idx
+=
1
segment_idx
+=
1
fail_count
=
0
fail_count
=
0
except
Exception
as
e
:
except
Exception
as
e
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment