Unverified commit c05ebad7 authored by Musisoul, committed by GitHub

Support hunyuan parallel vae (#560)

parent 58f84489
@@ -128,10 +128,75 @@ class HunyuanVideo15Runner(DefaultRunner):
             target_height // self.config["vae_stride"][1],
             target_width // self.config["vae_stride"][2],
         ]
-        self.target_height = target_height
-        self.target_width = target_width
+        ori_latent_h, ori_latent_w = latent_shape[2], latent_shape[3]
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            latent_h, latent_w, world_size_h, world_size_w = self._adjust_latent_for_grid_splitting(ori_latent_h, ori_latent_w, dist.get_world_size())
+            latent_shape[2], latent_shape[3] = latent_h, latent_w
+            logger.info(f"ori latent: {ori_latent_h}x{ori_latent_w}, adjust_latent: {latent_h}x{latent_w}, grid: {world_size_h}x{world_size_w}")
+        else:
+            latent_shape[2], latent_shape[3] = ori_latent_h, ori_latent_w
+            world_size_h, world_size_w = None, None
+        self.vae_decoder.world_size_h = world_size_h
+        self.vae_decoder.world_size_w = world_size_w
+        self.target_height = latent_shape[2] * self.config["vae_stride"][1]
+        self.target_width = latent_shape[3] * self.config["vae_stride"][2]
         return latent_shape
+
+    def _adjust_latent_for_grid_splitting(self, latent_h, latent_w, world_size):
+        """
+        Adjust latent dimensions for optimal 2D grid splitting.
+        Prefers balanced grids like 2x4 or 4x2 over 1x8 or 8x1.
+        """
+        world_size_h, world_size_w = 1, 1
+        if world_size <= 1:
+            return latent_h, latent_w, world_size_h, world_size_w
+
+        # Define priority grids for different world sizes
+        priority_grids = []
+        if world_size == 8:
+            # For 8 cards, prefer 2x4 and 4x2 over 1x8 and 8x1
+            priority_grids = [(2, 4), (4, 2), (1, 8), (8, 1)]
+        elif world_size == 4:
+            priority_grids = [(2, 2), (1, 4), (4, 1)]
+        elif world_size == 2:
+            priority_grids = [(1, 2), (2, 1)]
+        else:
+            # For other sizes, try factor pairs
+            for h in range(1, int(np.sqrt(world_size)) + 1):
+                if world_size % h == 0:
+                    w = world_size // h
+                    priority_grids.append((h, w))
+
+        # Try priority grids first
+        for world_size_h, world_size_w in priority_grids:
+            if latent_h % world_size_h == 0 and latent_w % world_size_w == 0:
+                return latent_h, latent_w, world_size_h, world_size_w
+
+        # If no perfect fit, find minimal padding solution
+        best_grid = (1, world_size)  # fallback
+        min_total_padding = float("inf")
+        for world_size_h, world_size_w in priority_grids:
+            # Calculate required padding
+            pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
+            pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
+            total_padding = pad_h + pad_w
+            # Prefer grids with minimal total padding
+            if total_padding < min_total_padding:
+                min_total_padding = total_padding
+                best_grid = (world_size_h, world_size_w)
+
+        # Apply padding
+        world_size_h, world_size_w = best_grid
+        pad_h = (world_size_h - (latent_h % world_size_h)) % world_size_h
+        pad_w = (world_size_w - (latent_w % world_size_w)) % world_size_w
+        return latent_h + pad_h, latent_w + pad_w, world_size_h, world_size_w
+
     def get_sr_latent_shape_with_target_hw(self):
         SizeMap = {
             "480p": 640,
@@ -254,6 +319,7 @@ class HunyuanVideo15Runner(DefaultRunner):
             "device": vae_device,
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
+            "parallel": self.config["parallel"],
         }
         if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
             return None
@@ -273,6 +339,7 @@ class HunyuanVideo15Runner(DefaultRunner):
             "device": vae_device,
             "cpu_offload": vae_offload,
             "dtype": GET_DTYPE(),
+            "parallel": self.config["parallel"],
         }
         if self.config.get("use_tae", False):
             tae_path = self.config["tae_path"]
...
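
The runner passes the grid to the VAE through the world_size_h and world_size_w attributes it sets on self.vae_decoder, while the new "parallel" entry in the two kwargs dicts above is forwarded to the HunyuanVideo15VAE constructor changed below; decode() only takes the distributed path when the flag is set and a grid has been provided, otherwise it falls back to the normal single-device decode.
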
@@ -5,6 +5,7 @@ from typing import Optional, Tuple, Union
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.autoencoders.vae import BaseOutput, DiagonalGaussianDistribution
@@ -787,9 +788,11 @@ class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
 class HunyuanVideo15VAE:
-    def __init__(self, checkpoint_path=None, dtype=torch.float16, device="cuda", cpu_offload=False):
+    def __init__(self, checkpoint_path=None, dtype=torch.float16, device="cuda", cpu_offload=False, parallel=False):
         self.vae = AutoencoderKLConv3D.from_pretrained(os.path.join(checkpoint_path, "vae")).to(dtype).to(device)
         self.vae.cpu_offload = cpu_offload
+        self.parallel = parallel
+        self.world_size_h, self.world_size_w = None, None
 
     @torch.no_grad()
     def encode(self, x):
@@ -800,10 +803,105 @@ class HunyuanVideo15VAE:
         z = z / self.vae.config.scaling_factor
         self.vae.enable_tiling()
-        video_frames = self.vae.decode(z, return_dict=False)[0]
+        if self.parallel and self.world_size_h is not None and self.world_size_w is not None:
+            video_frames = self.decode_dist_2d(z, self.world_size_h, self.world_size_w)
+            self.world_size_h, self.world_size_w = None, None
+        else:
+            video_frames = self.vae.decode(z, return_dict=False)[0]
         self.vae.disable_tiling()
         return video_frames
+
+    @torch.no_grad()
+    def decode_dist_2d(self, z, world_size_h, world_size_w):
+        cur_rank = dist.get_rank()
+        cur_rank_h = cur_rank // world_size_w
+        cur_rank_w = cur_rank % world_size_w
+
+        total_h = z.shape[3]
+        total_w = z.shape[4]
+        chunk_h = total_h // world_size_h
+        chunk_w = total_w // world_size_w
+        padding_size = 1
+
+        # Calculate H dimension slice
+        if cur_rank_h == 0:
+            h_start = 0
+            h_end = chunk_h + 2 * padding_size
+        elif cur_rank_h == world_size_h - 1:
+            h_start = total_h - (chunk_h + 2 * padding_size)
+            h_end = total_h
+        else:
+            h_start = cur_rank_h * chunk_h - padding_size
+            h_end = (cur_rank_h + 1) * chunk_h + padding_size
+
+        # Calculate W dimension slice
+        if cur_rank_w == 0:
+            w_start = 0
+            w_end = chunk_w + 2 * padding_size
+        elif cur_rank_w == world_size_w - 1:
+            w_start = total_w - (chunk_w + 2 * padding_size)
+            w_end = total_w
+        else:
+            w_start = cur_rank_w * chunk_w - padding_size
+            w_end = (cur_rank_w + 1) * chunk_w + padding_size
+
+        # Extract the latent chunk for this process
+        zs_chunk = z[:, :, :, h_start:h_end, w_start:w_end].contiguous()
+
+        # Decode the chunk
+        images_chunk = self.vae.decode(zs_chunk, return_dict=False)[0]
+
+        # Remove padding from decoded chunk
+        spatial_ratio = 16
+        if cur_rank_h == 0:
+            decoded_h_start = 0
+            decoded_h_end = chunk_h * spatial_ratio
+        elif cur_rank_h == world_size_h - 1:
+            decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
+            decoded_h_end = images_chunk.shape[3]
+        else:
+            decoded_h_start = padding_size * spatial_ratio
+            decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio
+
+        if cur_rank_w == 0:
+            decoded_w_start = 0
+            decoded_w_end = chunk_w * spatial_ratio
+        elif cur_rank_w == world_size_w - 1:
+            decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
+            decoded_w_end = images_chunk.shape[4]
+        else:
+            decoded_w_start = padding_size * spatial_ratio
+            decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio
+
+        images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()
+
+        # Gather all chunks
+        total_processes = world_size_h * world_size_w
+        full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
+        dist.all_gather(full_images, images_chunk)
+        self.device_synchronize()
+
+        # Reconstruct the full image tensor
+        image_rows = []
+        for h_idx in range(world_size_h):
+            image_cols = []
+            for w_idx in range(world_size_w):
+                process_idx = h_idx * world_size_w + w_idx
+                image_cols.append(full_images[process_idx])
+            image_rows.append(torch.cat(image_cols, dim=4))
+        images = torch.cat(image_rows, dim=3)
+        return images
+
+    def device_synchronize(
+        self,
+    ):
+        torch_device_module.synchronize()
+
 
 if __name__ == "__main__":
     vae = HunyuanVideo15VAE(checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5", dtype=torch.float16, device="cuda")
...
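
To make the slicing arithmetic in decode_dist_2d easier to follow, here is a small standalone sketch (illustrative only, not part of the patch; the function name chunk_bounds and the example sizes are hypothetical) that reproduces the per-rank latent window and the post-decode crop, assuming the patch's halo of one latent row/column (padding_size = 1) and 16x spatial upsampling (spatial_ratio = 16).

# Standalone sketch (not part of the patch): per-rank window and crop used by
# decode_dist_2d. chunk_bounds and the example numbers are hypothetical.
def chunk_bounds(rank, grid_h, grid_w, total_h, total_w, halo=1, ratio=16):
    row, col = rank // grid_w, rank % grid_w
    chunk_h, chunk_w = total_h // grid_h, total_w // grid_w

    def axis(idx, n, chunk, total):
        # Border ranks take the halo entirely from inside the volume; interior
        # ranks take one latent row/column of halo on each side.
        if idx == 0:
            return 0, chunk + 2 * halo
        if idx == n - 1:
            return total - (chunk + 2 * halo), total
        return idx * chunk - halo, (idx + 1) * chunk + halo

    h_bounds = axis(row, grid_h, chunk_h, total_h)
    w_bounds = axis(col, grid_w, chunk_w, total_w)
    # After decoding, each rank crops back to exactly chunk * ratio pixels per
    # axis, so all_gather followed by concatenation over W within each row and
    # then over H reassembles a seam-aligned full frame.
    return h_bounds, w_bounds, (chunk_h * ratio, chunk_w * ratio)

# Example: a 30x52 latent split on a 2x4 grid; rank 5 sits at row 1, col 1.
print(chunk_bounds(5, 2, 4, 30, 52))  # ((13, 30), (12, 27), (240, 208))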