Unverified commit 5b902afb, authored by LiangLiu and committed by GitHub
Browse files

Stream vae (#594)

parent 97549ed0
......@@ -138,11 +138,15 @@ class VAController:
dist.barrier()
def next_control(self):
    """Return the next control action for the active reader.

    When ``self.reader`` is an ``OmniVAReader`` the decision is delegated to
    ``self.omni_reader_next_control()``; for every other reader a plain
    ``NextControl(action="fetch")`` is returned.
    """
    # Function-scope import, kept exactly as in the original code.
    # NOTE(review): presumably lazy to avoid a circular/optional import — confirm.
    from lightx2v.deploy.common.va_reader_omni import OmniVAReader

    if not isinstance(self.reader, OmniVAReader):
        return NextControl(action="fetch")
    return self.omni_reader_next_control()
def before_control(self):
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
if isinstance(self.reader, OmniVAReader):
self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
......
......@@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio as ta
import torchvision.transforms.functional as TF
......@@ -711,13 +710,37 @@ class WanAudioRunner(WanRunner): # type:ignore
del video_seg, audio_seg
torch.cuda.empty_cache()
def get_rank_and_world_size(self):
    """Return ``(rank, world_size)`` for the current process.

    Falls back to ``(0, 1)`` when ``torch.distributed`` reports that it has
    not been initialized; otherwise the values come from ``dist.get_rank()``
    and ``dist.get_world_size()``.
    """
    if not dist.is_initialized():
        return 0, 1
    return dist.get_rank(), dist.get_world_size()
@ProfilingContext4DebugL1(
    "End run segment stream",
    recorder_mode=GET_RECORDER_MODE(),
    metrics_func=monitor_cli.lightx2v_run_end_run_segment_duration,
    metrics_labels=["WanAudioRunner"],
)
def end_run_segment_stream(self, latents):
    """Decode ``latents`` chunk by chunk and publish each chunk as it arrives.

    Every chunk yielded by ``self.run_vae_decoder_stream(latents)`` is clamped
    to [-1, 1] and cast to float. Only the frames that still fit inside the
    current segment (``end_frame - start_frame``) are converted with
    ``vae_to_comfyui_image_inplace`` and paired with the matching slice of
    ``self.segment.audio_array`` (summed over dim 0). When the controller has
    a recorder, that pair plus the kept frames is passed to ``pub_livestream``.

    Side effects:
        * ``self.prev_video`` is set to the concatenation (dim 2) of all
          clamped chunks, including frames beyond the segment's valid length.
        * ``torch.cuda.empty_cache()`` is called once at the end.

    NOTE(review): if the decoder yields no chunks, ``torch.cat`` receives an
    empty list — confirm this cannot happen.
    """
    frames_wanted = self.segment.end_frame - self.segment.start_frame
    decoded_chunks = []
    frames_done = 0

    # Chunk layout per the original comment: 1*C*1*H*W, 1*C*4*H*W, 1*C*4*H*W, ...
    for raw_chunk in self.run_vae_decoder_stream(latents):
        chunk = torch.clamp(raw_chunk, -1, 1).to(torch.float)

        # Frames of this chunk that still belong to the segment (0 once it is full).
        keep = min(frames_wanted - frames_done, chunk.shape[2])
        kept_frames = chunk[:, :, :keep]
        images = vae_to_comfyui_image_inplace(kept_frames.cpu())

        # Audio columns covering exactly the frames kept above.
        audio_rate = self._audio_processor.audio_frame_rate
        sample_lo = frames_done * audio_rate
        sample_hi = (frames_done + keep) * audio_rate
        mixed_audio = self.segment.audio_array[:, sample_lo:sample_hi].sum(dim=0)

        if self.va_controller.recorder is not None:
            self.va_controller.pub_livestream(images, mixed_audio, kept_frames)

        # The full chunk (not just the kept frames) feeds prev_video below.
        decoded_chunks.append(chunk)
        frames_done += keep
        del images, mixed_audio

    # Update prev_video for next iteration
    self.prev_video = torch.cat(decoded_chunks, dim=2)
    torch.cuda.empty_cache()
def run_main(self):
try:
......@@ -764,9 +787,12 @@ class WanAudioRunner(WanRunner): # type:ignore
self.check_stop()
latents = self.run_segment(segment_idx)
self.check_stop()
self.gen_video = self.run_vae_decoder(latents)
self.check_stop()
self.end_run_segment(segment_idx)
if self.config.get("use_stream_vae", False):
self.end_run_segment_stream(latents)
else:
self.gen_video = self.run_vae_decoder(latents)
self.check_stop()
self.end_run_segment(segment_idx)
segment_idx += 1
fail_count = 0
except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment