Commit c07946d8 authored by hepj's avatar hepj
Browse files

dit & video

parents
from dataclasses import dataclass
from fastvideo.v1.configs.sample.base import SamplingParam
@dataclass
class HunyuanSamplingParam(SamplingParam):
num_inference_steps: int = 50
num_frames: int = 125
height: int = 720
width: int = 1280
fps: int = 24
guidance_scale: float = 1.0
@dataclass
class FastHunyuanSamplingParam(HunyuanSamplingParam):
num_inference_steps: int = 6
import os
from typing import Any, Callable, Dict, Optional
from fastvideo.v1.configs.sample.hunyuan import (FastHunyuanSamplingParam,
HunyuanSamplingParam)
from fastvideo.v1.configs.sample.wan import (WanI2V480PSamplingParam,
WanT2V480PSamplingParam)
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import (maybe_download_model_index,
verify_model_config_and_directory)
logger = init_logger(__name__)
# Registry maps specific model weights to their config classes
SAMPLING_PARAM_REGISTRY: Dict[str, Any] = {
"FastVideo/FastHunyuan-diffusers": FastHunyuanSamplingParam,
"hunyuanvideo-community/HunyuanVideo": HunyuanSamplingParam,
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V480PSamplingParam,
"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V480PSamplingParam
# Add other specific weight variants
}
# For determining pipeline type from model ID
SAMPLING_PARAM_DETECTOR: Dict[str, Callable[[str], bool]] = {
"hunyuan": lambda id: "hunyuan" in id.lower(),
"wanpipeline": lambda id: "wanpipeline" in id.lower(),
"wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(),
# Add other pipeline architecture detectors
}
# Fallback configs when exact match isn't found but architecture is detected
SAMPLING_FALLBACK_PARAM: Dict[str, Any] = {
"hunyuan":
HunyuanSamplingParam, # Base Hunyuan config as fallback for any Hunyuan variant
"wanpipeline":
WanT2V480PSamplingParam, # Base Wan config as fallback for any Wan variant
"wanimagetovideo": WanI2V480PSamplingParam,
# Other fallbacks by architecture
}
def get_sampling_param_cls_for_name(
pipeline_name_or_path: str) -> Optional[Any]:
"""Get the appropriate sampling param for specific pretrained weights."""
if os.path.exists(pipeline_name_or_path):
config = verify_model_config_and_directory(pipeline_name_or_path)
logger.warning(
"FastVideo may not correctly identify the optimal sampling param for this model, as the local directory may have been renamed."
)
else:
config = maybe_download_model_index(pipeline_name_or_path)
pipeline_name = config["_class_name"]
# First try exact match for specific weights
if pipeline_name_or_path in SAMPLING_PARAM_REGISTRY:
return SAMPLING_PARAM_REGISTRY[pipeline_name_or_path]
# Try partial matches (for local paths that might include the weight ID)
for registered_id, config_class in SAMPLING_PARAM_REGISTRY.items():
if registered_id in pipeline_name_or_path:
return config_class
# If no match, try to use the fallback config
fallback_config = None
# Try to determine pipeline architecture for fallback
for pipeline_type, detector in SAMPLING_PARAM_DETECTOR.items():
if detector(pipeline_name.lower()):
fallback_config = SAMPLING_FALLBACK_PARAM.get(pipeline_type)
break
logger.warning(
"No match found for pipeline %s, using fallback sampling param %s.",
pipeline_name_or_path, fallback_config)
return fallback_config
from dataclasses import dataclass
from fastvideo.v1.configs.sample.base import SamplingParam
@dataclass
class WanT2V480PSamplingParam(SamplingParam):
# Video parameters
height: int = 480
width: int = 832
num_frames: int = 81
fps: int = 16
# Denoising stage
guidance_scale: float = 3.0
negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_inference_steps: int = 50
@dataclass
class WanI2V480PSamplingParam(WanT2V480PSamplingParam):
# Denoising stage
guidance_scale: float = 5.0
num_inference_steps: int = 40
num_gpus: 4
model_path: FastVideo/FastHunyuan-diffusers
master_port: 29503
sp_size: 4
tp_size: 4
height: 720
width: 1280
num_frames: 125
num_inference_steps: 6
guidance_scale: 1
embedded_cfg_scale: 6
flow_shift: 17
prompt_path: ./assets/prompt.txt
seed: 1024
output_path: outputs_video/
vae-sp: True
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
from fastvideo.v1.distributed.communication_op import *
from fastvideo.v1.distributed.parallel_state import (
cleanup_dist_env_and_memory, get_sequence_model_parallel_rank,
get_sequence_model_parallel_world_size, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_world_group,
init_distributed_environment, initialize_model_parallel)
from fastvideo.v1.distributed.utils import *
__all__ = [
"init_distributed_environment",
"initialize_model_parallel",
"get_sequence_model_parallel_rank",
"get_sequence_model_parallel_world_size",
"get_tensor_model_parallel_rank",
"get_tensor_model_parallel_world_size",
"cleanup_dist_env_and_memory",
"get_world_group",
]
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py
import torch
import torch.distributed
from fastvideo.v1.distributed.parallel_state import get_sp_group, get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)
# TODO: remove model, make it sequence_parallel
def sequence_model_parallel_all_to_all_4D(input_: torch.Tensor,
scatter_dim: int = 2,
gather_dim: int = 1) -> torch.Tensor:
"""All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group."""
return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim)
def sequence_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_sp_group().all_gather(input_, dim)
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
It can use the `cpu_group` to initialize the communicator.
If the device has PyTorch integration (PyTorch can recognize its
communication backend), the `device_group` will also be given.
"""
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def all_to_all_4D(self,
input_: torch.Tensor,
scatter_dim: int = 2,
gather_dim: int = 1) -> torch.Tensor:
"""Specialized all-to-all operation for 4D tensors (e.g., for QKV matrices).
Args:
input_ (torch.Tensor): 4D input tensor to be scattered and gathered.
scatter_dim (int, optional): Dimension along which to scatter. Defaults to 2.
gather_dim (int, optional): Dimension along which to gather. Defaults to 1.
Returns:
torch.Tensor: Output tensor after all-to-all operation.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
assert input_.dim(
) == 4, f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}"
if scatter_dim == 2 and gather_dim == 1:
# input: (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
bs, shard_seqlen, hc, hs = input_.shape
seqlen = shard_seqlen * self.world_size
shard_hc = hc // self.world_size
# Reshape and transpose for scattering
input_t = (input_.reshape(bs, shard_seqlen, self.world_size,
shard_hc, hs).transpose(0,
2).contiguous())
output = torch.empty_like(input_t)
torch.distributed.all_to_all_single(output,
input_t,
group=self.device_group)
torch.cuda.synchronize()
# Reshape and transpose back
output = output.reshape(seqlen, bs, shard_hc,
hs).transpose(0, 1).contiguous().reshape(
bs, seqlen, shard_hc, hs)
return output
elif scatter_dim == 1 and gather_dim == 2:
# input: (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
bs, seqlen, shard_hc, hs = input_.shape
hc = shard_hc * self.world_size
shard_seqlen = seqlen // self.world_size
# Reshape and transpose for scattering
input_t = (input_.reshape(bs, self.world_size, shard_seqlen,
shard_hc, hs).transpose(0, 3).transpose(
0, 1).contiguous().reshape(
self.world_size, shard_hc,
shard_seqlen, bs, hs))
output = torch.empty_like(input_t)
torch.distributed.all_to_all_single(output,
input_t,
group=self.device_group)
torch.cuda.synchronize()
# Reshape and transpose back
output = output.reshape(hc, shard_seqlen, bs,
hs).transpose(0, 2).contiguous().reshape(
bs, shard_seqlen, hc, hs)
return output
else:
raise RuntimeError(
"scatter_dim must be 1 or 2 and gather_dim must be 1 or 2")
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self) -> None:
pass
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/cuda_communicator.py
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from fastvideo.v1.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
class CudaCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
from fastvideo.v1.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
def all_reduce(self, input_):
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self) -> None:
if self.pynccl_comm is not None:
self.pynccl_comm = None
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl.py
from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
from fastvideo.v1.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId)
from fastvideo.v1.distributed.utils import StatelessProcessGroup
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import current_stream
logger = init_logger(__name__)
class PyNcclCommunicator:
def __init__(
self,
group: Union[ProcessGroup, StatelessProcessGroup],
device: Union[int, str, torch.device],
library_path: Optional[str] = None,
):
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the PyNcclCommunicator to. If None,
it will be bind to f"cuda:{local_rank}".
library_path: the path to the NCCL library. If None, it will
use the default library path.
It is the caller's responsibility to make sure each communicator
is bind to a unique device.
"""
if not isinstance(group, StatelessProcessGroup):
assert dist.is_initialized()
assert dist.get_backend(group) != dist.Backend.NCCL, (
"PyNcclCommunicator should be attached to a non-NCCL group.")
# note: this rank is the rank in the group
self.rank = dist.get_rank(group)
self.world_size = dist.get_world_size(group)
else:
self.rank = group.rank
self.world_size = group.world_size
self.group = group
# if world_size == 1, no need to create communicator
if self.world_size == 1:
self.available = False
self.disabled = True
return
try:
self.nccl = NCCLLibrary(library_path)
except Exception:
# disable because of missing NCCL library
# e.g. in a non-GPU environment
self.available = False
self.disabled = True
return
self.available = True
self.disabled = False
logger.info("FastVideo is using nccl==%s", self.nccl.ncclGetVersion())
if self.rank == 0:
# get the unique id from NCCL
self.unique_id = self.nccl.ncclGetUniqueId()
else:
# construct an empty unique id
self.unique_id = ncclUniqueId()
if not isinstance(group, StatelessProcessGroup):
tensor = torch.ByteTensor(list(self.unique_id.internal))
ranks = dist.get_process_group_ranks(group)
# arg `src` in `broadcast` is the global rank
dist.broadcast(tensor, src=ranks[0], group=group)
byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
else:
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
stream.synchronize()
del data
def all_reduce(self,
in_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled:
return None
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert in_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
ncclDataTypeEnum.from_torch(in_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
return out_tensor
def all_gather(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclAllGather(buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm, cudaStream_t(stream.cuda_stream))
def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = current_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
recvbuff = buffer_type(tensor.data_ptr())
else:
sendbuff = buffer_type()
recvbuff = buffer_type(tensor.data_ptr())
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl_wrapper.py
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `FASTVIDEO_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
#TODO(will): support FASTVIDEO_NCCL_SO_PATH
import ctypes
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import find_nccl_library
logger = init_logger(__name__)
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class NCCLLibrary:
exported_functions = [
# const char* ncclGetErrorString(ncclResult_t result)
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
# ncclResult_t ncclGetVersion(int *version);
Function("ncclGetVersion", ncclResult_t,
[ctypes.POINTER(ctypes.c_int)]),
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
Function("ncclGetUniqueId", ncclResult_t,
[ctypes.POINTER(ncclUniqueId)]),
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function("ncclCommInitRank", ncclResult_t, [
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int
]),
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllGather", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduceScatter", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Function("ncclSend", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclRecv(
# void* recvbuff, size_t count, ncclDataType_t datatype,
# int src, ncclComm_t comm, cudaStream_t stream);
Function("ncclRecv", ncclResult_t, [
buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function("ncclBroadcast", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ctypes.c_int, ncclComm_t, cudaStream_t
]),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_nccl_library()
try:
if so_file not in NCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
NCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = NCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
logger.error(
"Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"or it does not support the current platform %s."
"If you already have the library, please set the "
"environment variable FASTVIDEO_NCCL_SO_PATH"
" to point to the correct nccl library path.", so_file,
platform.platform())
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
def ncclGetErrorString(self, result: ncclResult_t) -> str:
return str(self._funcs["ncclGetErrorString"](result).decode("utf-8"))
def NCCL_CHECK(self, result: ncclResult_t) -> None:
if result != 0:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
def ncclGetUniqueId(self) -> ncclUniqueId:
unique_id = ncclUniqueId()
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
return unique_id
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
rank: int) -> ncclComm_t:
comm = ncclComm_t()
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
world_size, unique_id,
rank))
return comm
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
datatype, op, comm,
stream))
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
count, datatype, op,
comm, stream))
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# which is an aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
datatype, comm, stream))
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, dest,
comm, stream))
def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
comm, stream))
def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, root: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count,
datatype, root, comm,
stream))
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
"ncclComm_t", "cudaStream_t", "buffer_type"
]
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Adapted from
"""FastVideo distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
- any code dealing with the distributed stuff
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.
If you only need to use the distributed environment without model parallelism,
you can skip the model parallel initialization and destruction steps.
"""
import contextlib
import gc
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
import fastvideo.v1.envs as envs
from fastvideo.v1.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from fastvideo.v1.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from fastvideo.v1.distributed.utils import StatelessProcessGroup
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list: List[torch.Tensor] = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
_group_name_counter: Dict[str, int] = {}
def _get_unique_name(name: str) -> str:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if name not in _group_name_counter:
_group_name_counter[name] = 0
newname = f"{name}:{_group_name_counter[name]}"
_group_name_counter[name] += 1
return newname
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
def _register_group(group: "GroupCoordinator") -> None:
_groups[group.unique_name] = weakref.ref(group)
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It manages both CPU and device
communication.
"""
# available attributes:
rank: int # global rank
ranks: List[int] # global ranks in the group
world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
# Process | Node | Rank | Local Rank | Rank in Group
# 0 | 0 | 0 | 0 | 0
# 1 | 0 | 1 | 1 | 1
# 2 | 1 | 2 | 0 | 2
# 3 | 1 | 3 | 1 | 3
local_rank: int # local rank used to assign devices
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_device_communicator: bool # whether to use device communicator
device_communicator: DeviceCommunicatorBase # device communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
self,
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_device_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self.device_group = device_group
self.cpu_group = cpu_group
assert self.cpu_group is not None
assert self.device_group is not None
from fastvideo.v1.platforms import current_platform
# TODO: fix it for other platforms
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
self.device = torch.device("cpu")
self.use_device_communicator = use_device_communicator
self.device_communicator: DeviceCommunicatorBase = None # type: ignore
if use_device_communicator and self.world_size > 1:
# device_comm_cls = resolve_obj_by_qualname(
# current_platform.get_device_communicator_cls())
self.device_communicator = CudaCommunicator(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
)
self.mq_broadcaster = None
from fastvideo.v1.platforms import current_platform
# TODO(will): check if this is needed
# self.use_custom_op_call = current_platform.is_cuda_alike()
self.use_custom_op_call = False
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
return self.ranks[0]
@property
def last_rank(self):
"""Return the global rank of the last process in the group"""
return self.ranks[-1]
@property
def is_first_rank(self):
"""Return whether the caller is the first process in the group"""
return self.rank == self.first_rank
@property
def is_last_rank(self):
"""Return whether the caller is the last process in the group"""
return self.rank == self.last_rank
@property
def next_rank(self):
"""Return the global rank of the process that follows the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group + 1) % world_size]
@property
def prev_rank(self):
"""Return the global rank of the process that precedes the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group - 1) % world_size]
@contextmanager
def graph_capture(
self, graph_capture_context: Optional[GraphCaptureContext] = None):
if graph_capture_context is None:
stream = torch.cuda.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream):
yield graph_capture_context
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we always make the all-reduce operation
out-of-place.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if self.use_custom_op_call:
return torch.ops.vllm.all_reduce(input_,
group_name=self.unique_name)
else:
return self._all_reduce_out_place(input_)
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
return self.device_communicator.all_reduce(input_)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
return self.device_communicator.all_gather(input_, dim)
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
return self.device_communicator.gather(input_, dst, dim)
def all_to_all_4D(self,
input_: torch.Tensor,
scatter_dim: int = 2,
gather_dim: int = 1) -> torch.Tensor:
if self.world_size == 1:
return input_
return self.device_communicator.all_to_all_4D(input_, scatter_dim,
gather_dim)
def broadcast(self, input_: torch.Tensor, src: int = 0):
"""Broadcast the input tensor.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(input_,
src=self.ranks[src],
group=self.device_group)
return input_
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
"""Broadcast the input object.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj
if self.mq_broadcaster is not None:
assert src == 0, "Message queue broadcaster only supports src=0"
return self.mq_broadcaster.broadcast_object(obj)
if self.rank_in_group == src:
torch.distributed.broadcast_object_list([obj],
src=self.ranks[src],
group=self.cpu_group)
return obj
else:
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=self.ranks[src],
group=self.cpu_group)
return recv[0]
def broadcast_object_list(self,
obj_list: List[Any],
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input object list.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(obj_list,
src=self.ranks[src],
group=self.device_group)
return obj_list
def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank_in_group, (
"Invalid destination rank. Destination rank is the same "
"as the current rank.")
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
size_tensor = torch.tensor([object_tensor.numel()],
dtype=torch.long,
device="cpu")
# Send object size
torch.distributed.send(size_tensor,
dst=self.ranks[dst],
group=self.cpu_group)
# Send object
torch.distributed.send(object_tensor,
dst=self.ranks[dst],
group=self.cpu_group)
return None
def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"
assert src != self.rank_in_group, (
"Invalid source rank. Source rank is the same as the current rank.")
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
# Receive object size
rank_size = torch.distributed.recv(size_tensor,
src=self.ranks[src],
group=self.cpu_group)
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu")
rank_object = torch.distributed.recv(object_tensor,
src=self.ranks[src],
group=self.cpu_group)
assert rank_object == rank_size, (
"Received object sender rank does not match the size sender rank.")
obj = pickle.loads(object_tensor.numpy().tobytes())
return obj
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized() or self.world_size == 1):
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
assert src < self.world_size, f"Invalid src rank ({src})"
rank_in_group = self.rank_in_group
if rank_in_group == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.broadcast_object(metadata_list, src=src)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else:
metadata_list = self.broadcast_object(None, src=src)
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
if (all_gather_group is not None
and tensor.numel() % all_gather_size == 0):
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor,
dst=self.ranks[dst],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None
def recv_tensor_dict(
self,
src: Optional[int] = None,
all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
all_gather_size = (1 if all_gather_group is None else
all_gather_group.world_size)
all_gather_rank = (0 if all_gather_group is None else
all_gather_group.rank_in_group)
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: Dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (all_gather_group is not None
and tensor.numel() % all_gather_size == 0)
if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size,
-1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
def barrier(self):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
terrible because it is internally a broadcast operation with
secretly created GPU tensors. It is easy to mess up the current
device. Use the CPU group instead.
"""
torch.distributed.barrier(group=self.cpu_group)
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
self.device_communicator.send(tensor, dst)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
return self.device_communicator.recv(size, dtype, src)
def destroy(self) -> None:
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)
self.device_group = None
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
self.cpu_group = None
if self.device_communicator is not None:
self.device_communicator.destroy()
if self.mq_broadcaster is not None:
self.mq_broadcaster = None
_WORLD: Optional[GroupCoordinator] = None
def get_world_group() -> GroupCoordinator:
assert _WORLD is not None, ("world group is not initialized")
return _WORLD
def init_world_group(ranks: List[int], local_rank: int,
backend: str) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=False,
group_name="world",
)
def init_model_parallel_group(
group_ranks: List[List[int]],
local_rank: int,
backend: str,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
_TP: Optional[GroupCoordinator] = None
def get_tp_group() -> GroupCoordinator:
assert _TP is not None, ("tensor model parallel group is not initialized")
return _TP
# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group
_ENABLE_CUSTOM_ALL_REDUCE = True
def set_custom_all_reduce(enable: bool):
global _ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = "nccl",
):
logger.debug(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment")
# this backend is used for WORLD
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1:
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
_SP: Optional[GroupCoordinator] = None
def get_sp_group() -> GroupCoordinator:
assert _SP is not None, ("sequence model parallel group is not initialized")
return _SP
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
sequence_model_parallel_size: int = 1,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
sequence_model_parallel_size: number of GPUs used for sequence model
parallelism.
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
global _TP
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(
range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="tp")
# Build the sequence model-parallel groups.
num_sequence_model_parallel_groups: int = (world_size //
sequence_model_parallel_size)
global _SP
assert _SP is None, ("sequence model parallel group is already initialized")
group_ranks = []
# Since SP is incompatible with TP and PP, we can use a simpler group creation logic
for i in range(num_sequence_model_parallel_groups):
# Create groups of consecutive ranks
ranks = list(
range(i * sequence_model_parallel_size,
(i + 1) * sequence_model_parallel_size))
group_ranks.append(ranks)
_SP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="sp")
def get_sequence_model_parallel_world_size() -> int:
"""Return world size for the sequence model parallel group."""
return get_sp_group().world_size
def get_sequence_model_parallel_rank() -> int:
"""Return my rank for the sequence model parallel group."""
return get_sp_group().rank_in_group
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
sequence_model_parallel_size: int,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel, sequence-parallel sizes
are equal to expected values if the model parallel groups are initialized.
"""
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
sequence_model_parallel_size, backend)
return
assert (
get_tensor_model_parallel_world_size() == tensor_model_parallel_size
), ("tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
if sequence_model_parallel_size > 1:
sp_world_size = get_sp_group().world_size
assert (sp_world_size == sequence_model_parallel_size), (
"sequence parallel group already initialized, but of unexpected size: "
f"{sp_world_size=} vs. "
f"{sequence_model_parallel_size=}")
def model_parallel_is_initialized() -> bool:
"""Check if tensor, sequence parallel groups are initialized."""
return _TP is not None and _SP is not None
_TP_STATE_PATCHED = False
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _TP_STATE_PATCHED
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
_TP_STATE_PATCHED = True
old_tp_group = get_tp_group()
global _TP
_TP = tp_group
try:
yield
finally:
# restore the original state
_TP_STATE_PATCHED = False
_TP = old_tp_group
def get_tensor_model_parallel_world_size() -> int:
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
def get_tensor_model_parallel_rank() -> int:
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group
def destroy_model_parallel() -> None:
"""Set the groups to none and destroy them."""
global _TP
if _TP:
_TP.destroy()
_TP = None
global _SP
if _SP:
_SP.destroy()
_SP = None
def destroy_distributed_environment() -> None:
global _WORLD
if _WORLD:
_WORLD.destroy()
_WORLD = None
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
if shutdown_ray:
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
from fastvideo.v1.platforms import current_platform
if not current_platform.is_cpu():
torch.cuda.empty_cache()
try:
torch._C._host_emptyCache()
except AttributeError:
logger.warning(
"torch._C._host_emptyCache() only available in Pytorch >=2.5")
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
source_rank: int = 0) -> List[bool]:
"""
This is a collective operation that returns if each rank is in the same node
as the source rank. It tests if processes are attached to the same
memory system (shared access to shared memory).
"""
if isinstance(pg, ProcessGroup):
assert torch.distributed.get_backend(
pg) != torch.distributed.Backend.NCCL, (
"in_the_same_node_as should be tested with a non-NCCL group.")
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)
else:
rank = pg.rank
world_size = pg.world_size
ranks = list(range(world_size))
# local tensor in each process to store the result
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
magic_message = b"magic_message"
shm = None
try:
with contextlib.suppress(OSError):
if rank == source_rank:
# create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[:len(magic_message)] = magic_message
if isinstance(pg, ProcessGroup):
torch.distributed.broadcast_object_list(
[shm.name], src=ranks[source_rank], group=pg)
else:
pg.broadcast_obj(shm.name, src=source_rank)
is_in_the_same_node[rank] = 1
else:
# try to open the shared memory segment
if isinstance(pg, ProcessGroup):
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=ranks[source_rank], group=pg)
name = recv[0]
else:
name = pg.broadcast_obj(None, src=source_rank)
# fix to https://stackoverflow.com/q/62748654/9191338
# Python incorrectly tracks shared memory even if it is not
# created by the process. The following patch is a workaround.
with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None):
shm = shared_memory.SharedMemory(name=name)
if shm.buf[:len(magic_message)] == magic_message:
is_in_the_same_node[rank] = 1
except Exception as e:
logger.error("Error ignored in is_in_the_same_node: %s", e)
finally:
if shm:
shm.close()
if isinstance(pg, ProcessGroup):
torch.distributed.barrier(group=pg)
else:
pg.barrier()
# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == source_rank and shm:
shm.unlink()
if isinstance(pg, ProcessGroup):
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
aggregated_data = is_in_the_same_node
else:
aggregated_data = torch.zeros_like(is_in_the_same_node)
for i in range(world_size):
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
aggregated_data += rank_data
return [x == 1 for x in aggregated_data.tolist()]
def initialize_tensor_parallel_group(
tensor_model_parallel_size: int = 1,
backend: Optional[str] = None,
group_name_suffix: str = "") -> GroupCoordinator:
"""Initialize a tensor parallel group for a specific model.
This function creates a tensor parallel group that can be used with the
patch_tensor_parallel_group context manager. It allows different models
to use different tensor parallelism configurations.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
backend: communication backend to use.
group_name_suffix: optional suffix to make the group name unique.
Returns:
A GroupCoordinator for tensor parallelism that can be used with
the patch_tensor_parallel_group context manager.
Example usage:
```python
# Initialize tensor parallel group for model1
tp_group_model1 = initialize_tensor_parallel_group(
tensor_model_parallel_size=4,
group_name_suffix="model1"
)
# Use tensor parallelism for model1
with patch_tensor_parallel_group(tp_group_model1):
# Run model1 with tensor parallelism
output1 = model1(input1)
```
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
# Ensure the world size is compatible with the parallelism configuration
assert world_size % tensor_model_parallel_size == 0, \
f"World size ({world_size}) must be divisible by tensor_model_parallel_size ({tensor_model_parallel_size})"
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
tp_group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(
range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size))
tp_group_ranks.append(ranks)
# Create TP group coordinator with a unique name
group_name = f"tp_{group_name_suffix}" if group_name_suffix else "tp"
tp_group = init_model_parallel_group(tp_group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name=group_name)
return tp_group
def initialize_sequence_parallel_group(
sequence_model_parallel_size: int = 1,
backend: Optional[str] = None,
group_name_suffix: str = "") -> GroupCoordinator:
"""Initialize a sequence parallel group for a specific model.
This function creates a sequence parallel group that can be used with the
patch_sequence_parallel_group context manager. It allows different models
to use different sequence parallelism configurations.
Arguments:
sequence_model_parallel_size: number of GPUs used for sequence model parallelism.
backend: communication backend to use.
group_name_suffix: optional suffix to make the group name unique.
Returns:
A GroupCoordinator for sequence parallelism that can be used with
the patch_sequence_parallel_group context manager.
Example usage:
```python
# Initialize sequence parallel group for model2
sp_group_model2 = initialize_sequence_parallel_group(
sequence_model_parallel_size=2,
group_name_suffix="model2"
)
# Use sequence parallelism for model2
with patch_sequence_parallel_group(sp_group_model2):
# Run model2 with sequence parallelism
output2 = model2(input2)
```
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
# Ensure the world size is compatible with the parallelism configuration
assert world_size % sequence_model_parallel_size == 0, \
f"World size ({world_size}) must be divisible by sequence_model_parallel_size ({sequence_model_parallel_size})"
# Build the sequence model-parallel groups.
num_sequence_model_parallel_groups: int = (world_size //
sequence_model_parallel_size)
sp_group_ranks = []
for i in range(num_sequence_model_parallel_groups):
# Create groups of consecutive ranks
ranks = list(
range(i * sequence_model_parallel_size,
(i + 1) * sequence_model_parallel_size))
sp_group_ranks.append(ranks)
# Create SP group coordinator with a unique name
group_name = f"sp_{group_name_suffix}" if group_name_suffix else "sp"
sp_group = init_model_parallel_group(sp_group_ranks,
get_world_group().local_rank,
backend,
group_name=group_name)
return sp_group
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/utils.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import pickle
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch
from torch.distributed import TCPStore
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator) -> None:
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator)
def divide(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tuple(tensor_list)
@dataclasses.dataclass
class StatelessProcessGroup:
"""A dataclass to hold a metadata store, and the rank, world_size of the
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
default_factory=dict)
# A deque to store the data entries, with key and timestamp.
entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
def __post_init__(self):
assert self.rank < self.world_size
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
def expire_data(self) -> None:
"""Expire data that is older than `data_expiration_seconds` seconds."""
while self.entries:
# check the oldest entry
key, timestamp = self.entries[0]
if time.time() - timestamp > self.data_expiration_seconds:
self.store.delete_key(key)
self.entries.popleft()
else:
break
def recv_obj(self, src: int) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
self.recv_src_counter[src] += 1
return obj
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
"""Broadcast an object from a source rank to all other ranks.
It does not clean up after all ranks have received the object.
Use it for limited times, e.g., for initialization.
"""
if self.rank == src:
self.expire_data()
key = (f"broadcast_from/{src}/"
f"{self.broadcast_send_counter}")
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = (f"broadcast_from/{src}/"
f"{self.broadcast_recv_src_counter[src]}")
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return recv_obj
def all_gather_obj(self, obj: Any) -> list[Any]:
"""All gather an object from all ranks."""
gathered_objs = []
for i in range(self.world_size):
if i == self.rank:
gathered_objs.append(obj)
self.broadcast_obj(obj, src=self.rank)
else:
recv_obj = self.broadcast_obj(None, src=i)
gathered_objs.append(recv_obj)
return gathered_objs
def barrier(self):
"""A barrier to synchronize all ranks."""
for i in range(self.world_size):
if i == self.rank:
self.broadcast_obj(None, src=self.rank)
else:
self.broadcast_obj(None, src=i)
@staticmethod
def create(
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
If we have process A and process B called `torch.distributed.init_process_group`
to form a group, and then we want to form another group with process A, B, C,
D, it is not possible in PyTorch, because process A and process B have already
formed a group, and process C and process D cannot join that group. This
function is a workaround for this issue.
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It will return a `StatelessProcessGroup` object that can be
used for exchanging metadata. With this function, process A and process B
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
data_expiration_seconds=data_expiration_seconds)
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/types.py
import argparse
from fastvideo.v1.utils import FlexibleArgumentParser
class CLISubcommand:
"""Base class for CLI subcommands"""
name: str
def cmd(self, args: argparse.Namespace) -> None:
"""Execute the command with the given arguments"""
raise NotImplementedError
def validate(self, args: argparse.Namespace) -> None:
"""Validate the arguments for this command"""
pass
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
"""Initialize the subparser for this command"""
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
import argparse
from typing import List, cast
from fastvideo.v1.entrypoints.cli import utils
from fastvideo.v1.entrypoints.cli.cli_types import CLISubcommand
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.utils import FlexibleArgumentParser
class GenerateSubcommand(CLISubcommand):
"""The `generate` subcommand for the FastVideo CLI"""
def __init__(self) -> None:
self.name = "generate"
super().__init__()
def cmd(self, args: argparse.Namespace) -> None:
excluded_args = [
'subparser', 'config', 'num_gpus', 'master_port',
'dispatch_function'
]
# Create a filtered dictionary of arguments
filtered_args = {
k: v
for k, v in vars(args).items()
if k not in excluded_args and v is not None
}
main_args = []
for key, value in filtered_args.items():
# Convert underscores to dashes in argument names
arg_name = f"--{key.replace('_', '-')}"
# Handle boolean flags
if isinstance(value, bool):
if value:
main_args.append(arg_name)
else:
main_args.append(arg_name)
main_args.append(str(value))
utils.launch_distributed(args.num_gpus,
main_args,
master_port=args.master_port)
def validate(self, args: argparse.Namespace) -> None:
if args.num_gpus is not None and args.num_gpus <= 0:
raise ValueError("Number of gpus must be positive")
if args.master_port is not None and (args.master_port < 1024
or args.master_port > 65535):
raise ValueError("Master port must be between 1024 and 65535")
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
generate_parser = subparsers.add_parser(
"generate",
help="Run inference on a model",
usage=
"fastvideo generate --model-path MODEL_PATH_OR_ID --prompt PROMPT [OPTIONS]"
)
generate_parser.add_argument(
"--config",
type=str,
default='',
required=False,
help="Read CLI options from a config YAML file.")
generate_parser.add_argument("--master-port",
type=int,
default=None,
help="Port for the master process")
generate_parser = FastVideoArgs.add_cli_args(generate_parser)
return cast(FlexibleArgumentParser, generate_parser)
def cmd_init() -> List[CLISubcommand]:
return [GenerateSubcommand()]
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/main.py
from typing import List
from fastvideo.v1.entrypoints.cli.cli_types import CLISubcommand
from fastvideo.v1.entrypoints.cli.generate import cmd_init as generate_cmd_init
from fastvideo.v1.utils import FlexibleArgumentParser
def cmd_init() -> List[CLISubcommand]:
"""Initialize all commands from separate modules"""
commands = []
commands.extend(generate_cmd_init())
return commands
def main() -> None:
parser = FlexibleArgumentParser(description="FastVideo CLI")
parser.add_argument('-v', '--version', action='version', version='0.1.0')
subparsers = parser.add_subparsers(required=False, dest="subparser")
cmds = {}
for cmd in cmd_init():
cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
cmds[cmd.name] = cmd
args = parser.parse_args()
if args.subparser in cmds:
cmds[args.subparser].validate(args)
if hasattr(args, "dispatch_function"):
args.dispatch_function(args)
else:
parser.print_help()
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
import os
import subprocess
import sys
from typing import List, Optional
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
def launch_distributed(num_gpus: int,
args: List[str],
master_port: Optional[int] = None) -> int:
"""
Launch a distributed job with the given arguments
Args:
num_gpus: Number of GPUs to use
args: Arguments to pass to v1_fastvideo_inference.py (defaults to sys.argv[1:])
master_port: Port for the master process (default: random)
"""
current_env = os.environ.copy()
python_executable = sys.executable
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../.."))
main_script = os.path.join(project_root,
"fastvideo/v1/sample/v1_fastvideo_inference.py")
cmd = [
python_executable, "-m", "torch.distributed.run",
f"--nproc_per_node={num_gpus}"
]
if master_port is not None:
cmd.append(f"--master_port={master_port}")
cmd.append(main_script)
cmd.extend(args)
logger.info("Running inference with %d GPU(s)", num_gpus)
logger.info("Launching command: %s", " ".join(cmd))
current_env["PYTHONIOENCODING"] = "utf-8"
process = subprocess.Popen(cmd,
env=current_env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
encoding='utf-8',
errors='replace')
if process.stdout:
for line in iter(process.stdout.readline, ''):
print(line.strip())
return process.wait()
# SPDX-License-Identifier: Apache-2.0
"""
VideoGenerator module for FastVideo.
This module provides a consolidated interface for generating videos using
diffusion models.
"""
import os
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Union
import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange
from fastvideo.v1.configs.pipelines import (PipelineConfig,
get_pipeline_config_cls_for_name)
from fastvideo.v1.configs.sample import SamplingParam
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines import ForwardBatch
from fastvideo.v1.utils import align_to, shallow_asdict
from fastvideo.v1.worker.executor import Executor
logger = init_logger(__name__)
class VideoGenerator:
"""
A unified class for generating videos using diffusion models.
This class provides a simple interface for video generation with rich
customization options, similar to popular frameworks like HF Diffusers.
"""
def __init__(self, fastvideo_args: FastVideoArgs,
executor_class: type[Executor], log_stats: bool):
"""
Initialize the video generator.
Args:
pipeline: The pipeline to use for inference
fastvideo_args: The inference arguments
"""
self.fastvideo_args = fastvideo_args
self.executor = executor_class(fastvideo_args)
@classmethod
def from_pretrained(cls,
model_path: str,
device: Optional[str] = None,
torch_dtype: Optional[torch.dtype] = None,
pipeline_config: Optional[
Union[str
| PipelineConfig]] = None,
**kwargs) -> "VideoGenerator":
"""
Create a video generator from a pretrained model.
Args:
model_path: Path or identifier for the pretrained model
device: Device to load the model on (e.g., "cuda", "cuda:0", "cpu")
torch_dtype: Data type for model weights (e.g., torch.float16)
**kwargs: Additional arguments to customize model loading
Returns:
The created video generator
Priority level: Default pipeline config < User's pipeline config < User's kwargs
"""
config = None
# 1. If users provide a pipeline config, it will override the default pipeline config
if isinstance(pipeline_config, PipelineConfig):
config = pipeline_config
else:
config_cls = get_pipeline_config_cls_for_name(model_path)
if config_cls is not None:
config = config_cls()
if isinstance(pipeline_config, str):
config.load_from_json(pipeline_config)
# 2. If users also provide some kwargs, it will override the pipeline config.
# The user kwargs shouldn't contain model config parameters!
if config is None:
logger.warning("No config found for model %s, using default config",
model_path)
config_args = kwargs
else:
config_args = shallow_asdict(config)
config_args.update(kwargs)
fastvideo_args = FastVideoArgs(
model_path=model_path,
device_str=device or "cuda" if torch.cuda.is_available() else "cpu",
**config_args)
fastvideo_args.check_fastvideo_args()
return cls.from_fastvideo_args(fastvideo_args)
@classmethod
def from_fastvideo_args(cls,
fastvideo_args: FastVideoArgs) -> "VideoGenerator":
"""
Create a video generator with the specified arguments.
Args:
fastvideo_args: The inference arguments
Returns:
The created video generator
"""
# Initialize distributed environment if needed
# initialize_distributed_and_parallelism(fastvideo_args)
executor_class = Executor.get_class(fastvideo_args)
return cls(
fastvideo_args=fastvideo_args,
executor_class=executor_class,
log_stats=False, # TODO: implement
)
def generate_video(
self,
prompt: str,
sampling_param: Optional[SamplingParam] = None,
**kwargs,
) -> Union[Dict[str, Any], List[np.ndarray]]:
"""
Generate a video based on the given prompt.
Args:
prompt: The prompt to use for generation
negative_prompt: The negative prompt to use (overrides the one in fastvideo_args)
output_path: Path to save the video (overrides the one in fastvideo_args)
save_video: Whether to save the video to disk
return_frames: Whether to return the raw frames
num_inference_steps: Number of denoising steps (overrides fastvideo_args)
guidance_scale: Classifier-free guidance scale (overrides fastvideo_args)
num_frames: Number of frames to generate (overrides fastvideo_args)
height: Height of generated video (overrides fastvideo_args)
width: Width of generated video (overrides fastvideo_args)
fps: Frames per second for saved video (overrides fastvideo_args)
seed: Random seed for generation (overrides fastvideo_args)
callback: Callback function called after each step
callback_steps: Number of steps between each callback
Returns:
Either the output dictionary or the list of frames depending on return_frames
"""
# Create a copy of inference args to avoid modifying the original
fastvideo_args = self.fastvideo_args
# Validate inputs
if not isinstance(prompt, str):
raise TypeError(
f"`prompt` must be a string, but got {type(prompt)}")
prompt = prompt.strip()
if sampling_param is None:
sampling_param = SamplingParam.from_pretrained(
fastvideo_args.model_path)
kwargs["prompt"] = prompt
sampling_param.update(kwargs)
# Process negative prompt
if sampling_param.negative_prompt is not None:
sampling_param.negative_prompt = sampling_param.negative_prompt.strip(
)
# Validate dimensions
if (sampling_param.height <= 0 or sampling_param.width <= 0
or sampling_param.num_frames <= 0):
raise ValueError(
f"Height, width, and num_frames must be positive integers, got "
f"height={sampling_param.height}, width={sampling_param.width}, "
f"num_frames={sampling_param.num_frames}")
if (
sampling_param.num_frames - 1
) % fastvideo_args.vae_config.arch_config.temporal_compression_ratio != 0:
raise ValueError(
f"num_frames-1 must be a multiple of {fastvideo_args.vae_config.arch_config.temporal_compression_ratio}, got {sampling_param.num_frames}"
)
# Calculate sizes
target_height = align_to(sampling_param.height, 16)
target_width = align_to(sampling_param.width, 16)
# Calculate latent sizes
latents_size = [(sampling_param.num_frames - 1) // 4 + 1,
sampling_param.height // 8, sampling_param.width // 8]
n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
# Log parameters
debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {sampling_param.num_frames}
prompt: {prompt}
neg_prompt: {sampling_param.negative_prompt}
seed: {sampling_param.seed}
infer_steps: {sampling_param.num_inference_steps}
num_videos_per_prompt: {sampling_param.num_videos_per_prompt}
guidance_scale: {sampling_param.guidance_scale}
n_tokens: {n_tokens}
flow_shift: {fastvideo_args.flow_shift}
embedded_guidance_scale: {fastvideo_args.embedded_cfg_scale}"""
logger.info(debug_str)
# Prepare batch
batch = ForwardBatch(
**asdict(sampling_param),
eta=0.0,
n_tokens=n_tokens,
extra={},
)
# Run inference
start_time = time.time()
output_batch = self.executor.execute_forward(batch, fastvideo_args)
samples = output_batch
gen_time = time.time() - start_time
logger.info("Generated successfully in %.2f seconds", gen_time)
# Process outputs
videos = rearrange(samples, "b c t h w -> t b c h w")
frames = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=6)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
frames.append((x * 255).numpy().astype(np.uint8))
# Save video if requested
if batch.save_video:
save_path = batch.output_path
if save_path:
os.makedirs(os.path.dirname(save_path), exist_ok=True)
video_path = os.path.join(save_path, f"{prompt[:100]}.mp4")
imageio.mimsave(video_path, frames, fps=batch.fps)
logger.info("Saved video to %s", video_path)
else:
logger.warning("No output path provided, video not saved")
if batch.return_frames:
return frames
else:
return {
"samples": samples,
"prompts": prompt,
"size": (target_height, target_width, batch.num_frames),
"generation_time": gen_time
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment