Commit c93c756c authored by Yang Yong (雍洋), committed by GitHub

Support distributed VAE encode inference & remove the approximate_patch VAE due to its poor precision. (#255)

parent bba65ffd
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/distributed_env.py
import os

import torch.distributed as dist
from torch.distributed import ProcessGroup


class DistributedEnv:
    _vae_group = None
    _local_rank = None
    _world_size = None
    _rank_mapping = None

    @classmethod
    def initialize(cls, vae_group: ProcessGroup):
        if vae_group is None:
            cls._vae_group = dist.group.WORLD
        else:
            cls._vae_group = vae_group
        cls._local_rank = int(os.environ.get("LOCAL_RANK", 0))  # FIXME: in ray all local_rank is 0
        cls._rank_mapping = None
        cls._init_rank_mapping()

    @classmethod
    def get_vae_group(cls) -> ProcessGroup:
        if cls._vae_group is None:
            raise RuntimeError("DistributedEnv not initialized. Call initialize() first.")
        return cls._vae_group

    @classmethod
    def get_global_rank(cls) -> int:
        return dist.get_rank()

    @classmethod
    def _init_rank_mapping(cls):
        """Initialize the mapping between group ranks and global ranks."""
        if cls._rank_mapping is None:
            # Gather every member's global rank, ordered by rank in the group.
            ranks = [None] * cls.get_group_world_size()
            dist.all_gather_object(ranks, cls.get_global_rank(), group=cls.get_vae_group())
            cls._rank_mapping = ranks

    @classmethod
    def get_global_rank_from_group_rank(cls, group_rank: int) -> int:
        """Convert a rank in the VAE group to a global rank using the cached mapping.

        Args:
            group_rank: The rank in the VAE group.

        Returns:
            The corresponding global rank.

        Raises:
            RuntimeError: If group_rank is out of range.
        """
        if cls._rank_mapping is None:
            cls._init_rank_mapping()
        if group_rank < 0 or group_rank >= cls.get_group_world_size():
            raise RuntimeError(f"Invalid group rank: {group_rank}. Must be in range [0, {cls.get_group_world_size() - 1}]")
        return cls._rank_mapping[group_rank]

    @classmethod
    def get_rank_in_vae_group(cls) -> int:
        return dist.get_rank(cls.get_vae_group())

    @classmethod
    def get_group_world_size(cls) -> int:
        return dist.get_world_size(cls.get_vae_group())

    @classmethod
    def get_local_rank(cls) -> int:
        return cls._local_rank
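A minimal usage sketch (an editor's illustration, not part of the commit; the driver script and launch command are assumptions). Passing `None` makes `DistributedEnv` fall back to the WORLD process group, so group ranks and global ranks coincide:

# Hypothetical driver; launch with: torchrun --nproc_per_node=2 demo.py
import os

import torch
import torch.distributed as dist

from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")  # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    DistributedEnv.initialize(None)  # None -> use dist.group.WORLD as the VAE group

    # With the WORLD group, the cached group-rank -> global-rank mapping is the identity.
    r = DistributedEnv.get_rank_in_vae_group()
    assert DistributedEnv.get_global_rank_from_group_rank(r) == DistributedEnv.get_global_rank()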
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/split_gather.py
import torch
import torch.distributed as dist

from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv


def _gather(patch_hidden_state, dim=-1, group=None):
    # NOTE: `dim` and `group` are currently unused; the gather is fixed to the
    # height axis (dim 3) and to DistributedEnv's VAE group.
    group_world_size = DistributedEnv.get_group_world_size()
    local_rank = DistributedEnv.get_local_rank()

    # Exchange each rank's patch height first, since the last rank may hold a
    # shorter slice than the others.
    patch_height_list = [torch.empty([1], dtype=torch.int64, device=f"cuda:{local_rank}") for _ in range(group_world_size)]
    dist.all_gather(
        patch_height_list,
        torch.tensor([patch_hidden_state.shape[3]], dtype=torch.int64, device=f"cuda:{local_rank}"),
        group=DistributedEnv.get_vae_group(),
    )

    patch_hidden_state_list = [
        torch.zeros(
            [patch_hidden_state.shape[0], patch_hidden_state.shape[1], patch_hidden_state.shape[2], patch_height_list[i].item(), patch_hidden_state.shape[4]],
            dtype=patch_hidden_state.dtype,
            device=f"cuda:{local_rank}",
            requires_grad=patch_hidden_state.requires_grad,
        )
        for i in range(group_world_size)
    ]
    dist.all_gather(
        patch_hidden_state_list,
        patch_hidden_state.contiguous(),
        group=DistributedEnv.get_vae_group(),
    )
    return torch.cat(patch_hidden_state_list, dim=3)


def _split(inputs, dim=-1, group=None):
    # NOTE: `dim` and `group` are currently unused; the split is fixed to the
    # height axis (dim 3). Chunks are ceil-divided, so only the last rank may
    # receive a shorter slice.
    group_world_size = DistributedEnv.get_group_world_size()
    rank_in_vae_group = DistributedEnv.get_rank_in_vae_group()

    height = inputs.shape[3]
    chunk = (height + group_world_size - 1) // group_world_size
    start_idx = chunk * rank_in_vae_group
    end_idx = min(chunk * (rank_in_vae_group + 1), height)
    return inputs[:, :, :, start_idx:end_idx, :].clone()


class _SplitForwardGatherBackward(torch.autograd.Function):
    """Split the input along the height axis; gather gradients on backward.

    Args:
        inputs: input tensor of shape [b, c, t, h, w].
        dim: split dimension (currently ignored; dim 3 is used).
        group: process group (currently ignored; the VAE group is used).
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _split(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.dim, ctx.group), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model-parallel region and concatenate; split
    gradients on backward.

    Args:
        inputs: input tensor of shape [b, c, t, h, w].
        dim: gather dimension (currently ignored; dim 3 is used).
        group: process group (currently ignored; the VAE group is used).
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _gather(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.dim, ctx.group), None, None


def split_forward_gather_backward(group, inputs, dim):
    return _SplitForwardGatherBackward.apply(inputs, dim, group)


def gather_forward_split_backward(group, inputs, dim):
    return _GatherForwardSplitBackward.apply(inputs, dim, group)
@@ -7,8 +7,6 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger
-from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv
-from lightx2v.models.video_encoders.hf.wan.dist.split_gather import gather_forward_split_backward, split_forward_gather_backward
 from lightx2v.utils.utils import load_weights
 
 __all__ = [
@@ -519,7 +517,6 @@
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
-        self.use_approximate_patch = False
 
         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256
@@ -550,12 +547,6 @@
             dropout,
         )
 
-    def enable_approximate_patch(self):
-        self.use_approximate_patch = True
-
-    def disable_approximate_patch(self):
-        self.use_approximate_patch = False
-
     def forward(self, x):
         mu, log_var = self.encode(x)
         z = self.reparameterize(mu, log_var)
@@ -638,9 +629,6 @@
         return enc
 
     def tiled_decode(self, z, scale):
-        if self.use_approximate_patch:
-            z = split_forward_gather_backward(None, z, 3)
-
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
         else:
@@ -690,8 +678,6 @@
             result_rows.append(torch.cat(result_row, dim=-1))
 
         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
-        if self.use_approximate_patch:
-            dec = gather_forward_split_backward(None, dec, 3)
 
         return dec
@@ -726,8 +712,6 @@
 
     def decode(self, z, scale):
         self.clear_cache()
-        if self.use_approximate_patch:
-            z = split_forward_gather_backward(None, z, 3)
 
         # z: [b,c,t,h,w]
         if isinstance(scale[0], torch.Tensor):
@@ -752,9 +736,6 @@
                 )
                 out = torch.cat([out, out_], 2)
 
-        if self.use_approximate_patch:
-            out = gather_forward_split_backward(None, out, 3)
-
         self.clear_cache()
         return out
@@ -866,12 +847,6 @@
 
         # init model
         self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
-        self.use_approximate_patch = False
-        if self.parallel and self.parallel.get("use_patch_vae", False):
-            # assert not self.use_tiling
-            DistributedEnv.initialize(None)
-            self.use_approximate_patch = True
-            self.model.enable_approximate_patch()
 
     def current_device(self):
         return next(self.model.parameters()).device
@@ -892,6 +867,70 @@
             self.inv_std = self.inv_std.cuda()
         self.scale = [self.mean, self.inv_std]
 
+    def encode_dist(self, video, world_size, cur_rank, split_dim):
+        spatial_ratio = 8
+        if split_dim == 3:
+            total_latent_len = video.shape[3] // spatial_ratio
+        elif split_dim == 4:
+            total_latent_len = video.shape[4] // spatial_ratio
+        else:
+            raise ValueError(f"Unsupported split_dim: {split_dim}")
+
+        splited_chunk_len = total_latent_len // world_size
+        padding_size = 1
+        video_chunk_len = splited_chunk_len * spatial_ratio
+        video_padding_len = padding_size * spatial_ratio
+
+        if cur_rank == 0:
+            if split_dim == 3:
+                video_chunk = video[:, :, :, : video_chunk_len + 2 * video_padding_len, :].contiguous()
+            elif split_dim == 4:
+                video_chunk = video[:, :, :, :, : video_chunk_len + 2 * video_padding_len].contiguous()
+        elif cur_rank == world_size - 1:
+            if split_dim == 3:
+                video_chunk = video[:, :, :, -(video_chunk_len + 2 * video_padding_len) :, :].contiguous()
+            elif split_dim == 4:
+                video_chunk = video[:, :, :, :, -(video_chunk_len + 2 * video_padding_len) :].contiguous()
+        else:
+            start_idx = cur_rank * video_chunk_len - video_padding_len
+            end_idx = (cur_rank + 1) * video_chunk_len + video_padding_len
+            if split_dim == 3:
+                video_chunk = video[:, :, :, start_idx:end_idx, :].contiguous()
+            elif split_dim == 4:
+                video_chunk = video[:, :, :, :, start_idx:end_idx].contiguous()
+
+        if self.use_tiling:
+            encoded_chunk = self.model.tiled_encode(video_chunk, self.scale).float()
+        else:
+            encoded_chunk = self.model.encode(video_chunk, self.scale).float()
+
+        if cur_rank == 0:
+            if split_dim == 3:
+                encoded_chunk = encoded_chunk[:, :, :, :splited_chunk_len, :].contiguous()
+            elif split_dim == 4:
+                encoded_chunk = encoded_chunk[:, :, :, :, :splited_chunk_len].contiguous()
+        elif cur_rank == world_size - 1:
+            if split_dim == 3:
+                encoded_chunk = encoded_chunk[:, :, :, -splited_chunk_len:, :].contiguous()
+            elif split_dim == 4:
+                encoded_chunk = encoded_chunk[:, :, :, :, -splited_chunk_len:].contiguous()
+        else:
+            if split_dim == 3:
+                encoded_chunk = encoded_chunk[:, :, :, padding_size:-padding_size, :].contiguous()
+            elif split_dim == 4:
+                encoded_chunk = encoded_chunk[:, :, :, :, padding_size:-padding_size].contiguous()
+
+        full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)]
+        dist.all_gather(full_encoded, encoded_chunk)
+        torch.cuda.synchronize()
+        encoded = torch.cat(full_encoded, dim=split_dim)
+        return encoded.squeeze(0)
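The slicing in encode_dist is easiest to check with concrete numbers; the standalone sketch below is an editor's illustration that mirrors the bounds computed above, not part of the commit. With H = 512, world_size = 4, and the fixed spatial_ratio of 8, each rank keeps 16 latent rows but encodes 144 pixel rows: its own 128 plus an 8-pixel halo on each side (the first and last ranks take the whole 16-pixel halo on their single inner side):

# Single-process check of the encode_dist slice arithmetic (illustrative only).
def chunk_bounds(length, world_size, spatial_ratio=8, padding_size=1):
    chunk = (length // spatial_ratio) // world_size  # latent rows kept per rank
    video_chunk = chunk * spatial_ratio              # pixel rows owned per rank
    halo = padding_size * spatial_ratio              # pixel overlap for boundary context
    spans = []
    for rank in range(world_size):
        if rank == 0:
            spans.append((0, video_chunk + 2 * halo))
        elif rank == world_size - 1:
            spans.append((length - (video_chunk + 2 * halo), length))
        else:
            spans.append((rank * video_chunk - halo, (rank + 1) * video_chunk + halo))
    return chunk, spans

chunk, spans = chunk_bounds(512, 4)
assert chunk == 16                              # 512 / 8 / 4 latent rows per rank
assert all(hi - lo == 144 for lo, hi in spans)  # 128 owned + 16 halo pixel rows each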
     def encode(self, video):
         """
         video: one video with shape [1, C, T, H, W].
@@ -899,10 +938,27 @@
         if self.cpu_offload:
             self.to_cuda()
-        if self.use_tiling:
-            out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
+        if self.parallel:
+            world_size = dist.get_world_size()
+            cur_rank = dist.get_rank()
+            height, width = video.shape[3], video.shape[4]
+            # Check if dimensions are divisible by world_size
+            if width % world_size == 0:
+                out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
+            elif height % world_size == 0:
+                out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
+            else:
+                logger.info("Fall back to naive encode mode")
+                if self.use_tiling:
+                    out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
+                else:
+                    out = self.model.encode(video, self.scale).float().squeeze(0)
         else:
-            out = self.model.encode(video, self.scale).float().squeeze(0)
+            if self.use_tiling:
+                out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
+            else:
+                out = self.model.encode(video, self.scale).float().squeeze(0)
 
         if self.cpu_offload:
             self.to_cpu()
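The dispatch above prefers splitting along width, then height, and otherwise falls back to single-device encoding. A standalone restatement of that check (editor's sketch, not part of the commit):

# Mirrors the divisibility test in encode (illustrative only).
def pick_split_dim(height, width, world_size):
    if width % world_size == 0:
        return 4  # shard along W
    if height % world_size == 0:
        return 3  # shard along H
    return None   # neither axis divides evenly -> naive encode

assert pick_split_dim(480, 832, 4) == 4     # 832 % 4 == 0, width wins
assert pick_split_dim(480, 833, 3) == 3     # width fails, 480 % 3 == 0
assert pick_split_dim(481, 833, 4) is None  # fall back to naive mode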
@@ -961,7 +1017,7 @@
         if self.cpu_offload:
             self.to_cuda()
-        if self.parallel and not self.use_approximate_patch:
+        if self.parallel:
             world_size = dist.get_world_size()
             cur_rank = dist.get_rank()
             height, width = zs.shape[2], zs.shape[3]
......