Unverified commit 5546f759 authored by Musisoul, committed by GitHub

[feat] stream vae (#582)

parent 0ad8ada3
@@ -319,7 +319,14 @@ class DefaultRunner(BaseRunner):
            # 2. main inference loop
            latents = self.run_segment(segment_idx)
            # 3. vae decoder
            if self.config.get("use_stream_vae", False):
                frames = []
                for frame_segment in self.run_vae_decoder_stream(latents):
                    frames.append(frame_segment)
                    logger.info(f"frame segment: {len(frames)} done")
                self.gen_video = torch.cat(frames, dim=2)
            else:
                self.gen_video = self.run_vae_decoder(latents)
            # 4. default do nothing
            self.end_run_segment(segment_idx)
        gen_video_final = self.process_images_after_vae_decoder()
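For reference, the streamed path rebuilds the same tensor the non-streamed decoder returns: each yielded frame_segment is a [b, c, t, h, w] clip, and concatenating along dim=2 stitches the segments back together on the temporal axis. A minimal sketch with dummy tensors (the shapes are illustrative, not taken from this PR):

import torch

# hypothetical per-segment outputs: [batch, channels, frames, height, width]
segments = [torch.randn(1, 3, 4, 64, 64) for _ in range(3)]

video = torch.cat(segments, dim=2)  # concatenate on the temporal axis
print(video.shape)  # torch.Size([1, 3, 12, 64, 64])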
@@ -337,6 +344,19 @@ class DefaultRunner(BaseRunner):
            gc.collect()
        return images
@ProfilingContext4DebugL1("Run VAE Decoder Stream", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
def run_vae_decoder_stream(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder()
for frame_segment in self.vae_decoder.decode_stream(latents.to(GET_DTYPE())):
yield frame_segment
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
    def post_prompt_enhancer(self):
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
...
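The new branch is driven entirely by config keys that the runner already reads (use_stream_vae, lazy_load, unload_modules). A hedged sketch of what enabling the stream VAE might look like as a config fragment; only those three flags come from this diff, the surrounding structure and defaults are assumptions:

# hypothetical runner config fragment; key names other than the three flags are illustrative
config = {
    "use_stream_vae": True,   # decode latents segment-by-segment via run_vae_decoder_stream
    "lazy_load": False,       # if True, the VAE decoder is (re)loaded around each decode
    "unload_modules": False,  # if True, the decoder is deleted and the CUDA cache cleared afterwards
}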
@@ -724,6 +724,25 @@ class WanVAE_(nn.Module):
        self.clear_cache()
        return out
    def decode_stream(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        # decode one latent frame per step, carrying the conv feature cache across steps,
        # and yield the decoded frames immediately instead of accumulating them
        for i in range(iter_):
            self._conv_idx = [0]
            out = self.decoder(
                x[:, :, i : i + 1, :, :],
                feat_cache=self._feat_map,
                feat_idx=self._conv_idx,
            )
            yield out
    def cached_decode(self, z, scale):
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
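Since WanVAE_.decode_stream denormalizes the latents and then decodes one latent frame per iteration with the shared feature cache, concatenating its outputs along the temporal axis should closely reproduce what the existing full decode path returns for the same input. A hedged sketch of that check, assuming an already constructed WanVAE_ instance named vae, a latent z of shape [b, c, t, h, w], and the existing decode(z, scale) method:

# hedged sketch; `vae`, `z`, and `scale` are assumed to exist
streamed = torch.cat(list(vae.decode_stream(z, scale)), dim=2)
full = vae.decode(z, scale)  # existing non-streamed path
assert streamed.shape == full.shape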
@@ -1291,6 +1310,87 @@ class WanVAE:
        return images
    def decode_dist_2d_stream(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
        total_h = zs.shape[2]
        total_w = zs.shape[3]
        chunk_h = total_h // world_size_h
        chunk_w = total_w // world_size_w
        padding_size = 1

        # Calculate H dimension slice
        if cur_rank_h == 0:
            h_start = 0
            h_end = chunk_h + 2 * padding_size
        elif cur_rank_h == world_size_h - 1:
            h_start = total_h - (chunk_h + 2 * padding_size)
            h_end = total_h
        else:
            h_start = cur_rank_h * chunk_h - padding_size
            h_end = (cur_rank_h + 1) * chunk_h + padding_size

        # Calculate W dimension slice
        if cur_rank_w == 0:
            w_start = 0
            w_end = chunk_w + 2 * padding_size
        elif cur_rank_w == world_size_w - 1:
            w_start = total_w - (chunk_w + 2 * padding_size)
            w_end = total_w
        else:
            w_start = cur_rank_w * chunk_w - padding_size
            w_end = (cur_rank_w + 1) * chunk_w + padding_size

        # Extract the latent chunk for this process
        zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()

        for image in self.model.decode_stream(zs_chunk.unsqueeze(0), self.scale):
            images_chunk = image.clamp_(-1, 1)

            # Remove padding from decoded chunk
            spatial_ratio = 8
            if cur_rank_h == 0:
                decoded_h_start = 0
                decoded_h_end = chunk_h * spatial_ratio
            elif cur_rank_h == world_size_h - 1:
                decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
                decoded_h_end = images_chunk.shape[3]
            else:
                decoded_h_start = padding_size * spatial_ratio
                decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

            if cur_rank_w == 0:
                decoded_w_start = 0
                decoded_w_end = chunk_w * spatial_ratio
            elif cur_rank_w == world_size_w - 1:
                decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
                decoded_w_end = images_chunk.shape[4]
            else:
                decoded_w_start = padding_size * spatial_ratio
                decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio

            images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()

            # Gather all chunks
            total_processes = world_size_h * world_size_w
            full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
            dist.all_gather(full_images, images_chunk)
            self.device_synchronize()

            # Reconstruct the full image tensor
            image_rows = []
            for h_idx in range(world_size_h):
                image_cols = []
                for w_idx in range(world_size_w):
                    process_idx = h_idx * world_size_w + w_idx
                    image_cols.append(full_images[process_idx])
                image_rows.append(torch.cat(image_cols, dim=4))
            images = torch.cat(image_rows, dim=3)

            yield images
    def decode(self, zs):
        if self.cpu_offload:
            self.to_cuda()
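The overlap logic can be read straight off the slice arithmetic: every rank decodes its own chunk plus a one-latent-pixel halo (padding_size = 1), and because the VAE upsamples spatially by 8x (spatial_ratio = 8), that halo becomes an 8-pixel border that is cropped off before the all_gather. A small standalone sketch of the H-dimension bookkeeping; the helper name is made up for illustration:

def latent_h_slice(total_h, world_size_h, cur_rank_h, padding_size=1):
    # mirrors the H-dimension branch of decode_dist_2d_stream
    chunk_h = total_h // world_size_h
    if cur_rank_h == 0:
        return 0, chunk_h + 2 * padding_size
    if cur_rank_h == world_size_h - 1:
        return total_h - (chunk_h + 2 * padding_size), total_h
    return cur_rank_h * chunk_h - padding_size, (cur_rank_h + 1) * chunk_h + padding_size

# e.g. 32 latent rows split across 4 ranks: every rank sees 10 rows (8 of its own plus a 2-row halo)
for rank in range(4):
    print(rank, latent_h_slice(32, 4, rank))
# (0, 10), (7, 17), (15, 25), (22, 32)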
@@ -1324,6 +1424,27 @@ class WanVAE:
        return images
    def decode_stream(self, zs):
        if self.cpu_offload:
            self.to_cuda()
        if self.parallel:
            world_size = dist.get_world_size()
            cur_rank = dist.get_rank()
            latent_height, latent_width = zs.shape[2], zs.shape[3]
            world_size_h, world_size_w = self._calculate_2d_grid(latent_height, latent_width, world_size)
            cur_rank_h = cur_rank // world_size_w
            cur_rank_w = cur_rank % world_size_w
            for images in self.decode_dist_2d_stream(zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
                yield images
        else:
            for image in self.model.decode_stream(zs.unsqueeze(0), self.scale):
                yield image.clamp_(-1, 1)
        if self.cpu_offload:
            self.to_cpu()
    def encode_video(self, vid):
        return self.model.encode_video(vid)
...
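At the WanVAE level the streaming API is a plain generator, so the offload bookkeeping at the end (to_cpu when cpu_offload is set) only runs once the generator has been fully consumed; a caller that breaks out early would skip it. A hedged consumption sketch, assuming a constructed WanVAE instance named vae and latents zs shaped [c, t, h, w]:

# hedged sketch; `vae` and `zs` are assumed to exist
frames = []
for segment in vae.decode_stream(zs):
    frames.append(segment)          # each yielded segment is already clamped to [-1, 1]
video = torch.cat(frames, dim=2)    # reassemble [b, c, t, h, w] along the temporal axis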