Commit c158e550 authored by gushiqiao, committed by GitHub

[fea] Update patch vae

parents d061ae81 9908528b
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/distributed_env.py
import torch.distributed as dist
from torch.distributed import ProcessGroup
import os
class DistributedEnv:
    """Process-group and rank bookkeeping for the distributed (patch-parallel) VAE."""

    _vae_group = None
    _local_rank = None
    _world_size = None
    _rank_mapping = None

    @classmethod
    def initialize(cls, vae_group: ProcessGroup):
        # Fall back to the default WORLD group when no dedicated VAE group is given.
        if vae_group is None:
            cls._vae_group = dist.group.WORLD
        else:
            cls._vae_group = vae_group
        cls._local_rank = int(os.environ.get('LOCAL_RANK', 0))  # FIXME: in ray all local_rank is 0
        cls._rank_mapping = None
        cls._init_rank_mapping()

    @classmethod
    def get_vae_group(cls) -> ProcessGroup:
        if cls._vae_group is None:
            raise RuntimeError("DistributedEnv not initialized. Call initialize() first.")
        return cls._vae_group

    @classmethod
    def get_global_rank(cls) -> int:
        return dist.get_rank()

    @classmethod
    def _init_rank_mapping(cls):
        """Initialize the mapping between group ranks and global ranks."""
        if cls._rank_mapping is None:
            # Gather the global rank of every process that belongs to the VAE group.
            ranks = [None] * cls.get_group_world_size()
            dist.all_gather_object(ranks, cls.get_global_rank(), group=cls.get_vae_group())
            cls._rank_mapping = ranks

    @classmethod
    def get_global_rank_from_group_rank(cls, group_rank: int) -> int:
        """Convert a rank in the VAE group to a global rank using the cached mapping.

        Args:
            group_rank: The rank in the VAE group.

        Returns:
            The corresponding global rank.

        Raises:
            RuntimeError: If the group_rank is invalid.
        """
        if cls._rank_mapping is None:
            cls._init_rank_mapping()
        if group_rank < 0 or group_rank >= cls.get_group_world_size():
            raise RuntimeError(f"Invalid group rank: {group_rank}. Must be in range [0, {cls.get_group_world_size() - 1}]")
        return cls._rank_mapping[group_rank]

    @classmethod
    def get_rank_in_vae_group(cls) -> int:
        return dist.get_rank(cls.get_vae_group())

    @classmethod
    def get_group_world_size(cls) -> int:
        return dist.get_world_size(cls.get_vae_group())

    @classmethod
    def get_local_rank(cls) -> int:
        return cls._local_rank
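
# --- Usage sketch (illustrative; not part of the ParaVAE source above) ---
# A minimal example, assuming a torchrun-style launch, of how DistributedEnv might be
# initialized before the patch-parallel helpers are used. The function name is hypothetical;
# passing None reuses dist.group.WORLD, while a dedicated group from dist.new_group() can be
# passed instead.
def _example_init_distributed_env():
    import torch.distributed as dist

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")  # assumes torchrun set RANK/WORLD_SIZE/LOCAL_RANK
    DistributedEnv.initialize(None)  # None -> use the default WORLD group for the VAE
    group_rank = DistributedEnv.get_rank_in_vae_group()
    # Map the group-local rank back to its global rank via the cached mapping.
    return DistributedEnv.get_global_rank_from_group_rank(group_rank)
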
# Code source: https://github.com/RiseAI-Sys/ParaVAE/blob/main/paravae/dist/split_gather.py
import torch
import torch.distributed as dist
from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv
def _gather(patch_hidden_state, dim=-1, group=None):
    """All-gather the height-wise patches from every rank in the VAE group and concatenate them."""
    group_world_size = DistributedEnv.get_group_world_size()
    local_rank = DistributedEnv.get_local_rank()
    # Exchange the (possibly uneven) patch heights so each rank can allocate correctly sized buffers.
    patch_height_list = [torch.empty([1], dtype=torch.int64, device=f"cuda:{local_rank}") for _ in range(group_world_size)]
    dist.all_gather(
        patch_height_list,
        torch.tensor(
            [patch_hidden_state.shape[3]],
            dtype=torch.int64,
            device=f"cuda:{local_rank}"
        ),
        group=DistributedEnv.get_vae_group()
    )
    # Allocate one receive buffer per rank, sized to that rank's patch height along dim 3.
    patch_hidden_state_list = [
        torch.zeros(
            [patch_hidden_state.shape[0], patch_hidden_state.shape[1], patch_hidden_state.shape[2], patch_height_list[i].item(), patch_hidden_state.shape[4]],
            dtype=patch_hidden_state.dtype,
            device=f"cuda:{local_rank}",
            requires_grad=patch_hidden_state.requires_grad
        ) for i in range(group_world_size)
    ]
    dist.all_gather(
        patch_hidden_state_list,
        patch_hidden_state.contiguous(),
        group=DistributedEnv.get_vae_group()
    )
    # Reassemble the full tensor along the height dimension.
    output = torch.cat(patch_hidden_state_list, dim=3)
    return output


def _split(inputs, dim=-1, group=None):
    """Slice this rank's height-wise patch out of the full tensor (ceil-divided chunks along dim 3)."""
    group_world_size = DistributedEnv.get_group_world_size()
    rank_in_vae_group = DistributedEnv.get_rank_in_vae_group()
    height = inputs.shape[3]
    # Each rank takes a contiguous chunk of ceil(height / world_size) rows; the last rank may get fewer.
    start_idx = (height + group_world_size - 1) // group_world_size * rank_in_vae_group
    end_idx = min((height + group_world_size - 1) // group_world_size * (rank_in_vae_group + 1), height)
    return inputs[:, :, :, start_idx:end_idx, :].clone()
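
# Illustration (not in the original source): with height = 10 and group_world_size = 3, the
# chunk size is ceil(10 / 3) = 4, so the per-rank slices along dim 3 are
#   rank 0 -> [0, 4),  rank 1 -> [4, 8),  rank 2 -> [8, 10)
# i.e. all ranks except possibly the last hold equal-sized patches, and _gather above relies
# on the exchanged patch heights to reassemble these uneven chunks.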
class _SplitForwardGatherBackward(torch.autograd.Function):
    """Split the input along the height dimension in the forward pass; gather the gradient in the backward pass.

    Args:
        inputs: input tensor.
        dim: dimension along which to split.
        group: process group.
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _split(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.dim, ctx.group), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model-parallel region and concatenate; split the gradient in the backward pass.

    Args:
        inputs: input tensor.
        dim: dimension along which to gather.
        group: process group.
    """

    @staticmethod
    def forward(ctx, inputs, dim, group):
        ctx.group = group
        ctx.dim = dim
        return _gather(inputs, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.dim, ctx.group), None, None


def split_forward_gather_backward(group, inputs, dim):
    return _SplitForwardGatherBackward.apply(inputs, dim, group)


def gather_forward_split_backward(group, inputs, dim):
    return _GatherForwardSplitBackward.apply(inputs, dim, group)
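
# --- Usage sketch (illustrative; not part of the ParaVAE source above) ---
# How the two wrappers are typically composed around a height-parallel VAE stage: split the
# 5-D latent (N, C, T, H, W) across ranks along H, run the local computation, then gather the
# full tensor back. `run_local_stage` is a placeholder for the per-rank VAE computation.
def _example_patch_parallel_stage(latents, run_local_stage):
    group = DistributedEnv.get_vae_group()
    # Forward: keep only this rank's H-chunk; backward: all-gather the gradient.
    local_patch = split_forward_gather_backward(group, latents, dim=3)
    local_out = run_local_stage(local_patch)
    # Forward: all-gather the per-rank outputs along H; backward: re-split the gradient.
    return gather_forward_split_backward(group, local_out, dim=3)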