Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copyright 2024 xDiT team.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import pickle
from collections import namedtuple
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed
from torch.cuda import synchronize
from torch.distributed import Backend, ProcessGroup
from sglang.multimodal_gen import envs
from sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase,
)
from sglang.multimodal_gen.runtime.distributed.device_communicators.cpu_communicator import (
CpuCommunicator,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
try:
import torch_musa # noqa: F401
from torch_musa.core.device import synchronize
except ModuleNotFoundError:
pass
logger = init_logger(__name__)
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
_group_name_counter: dict[str, int] = {}
def get_local_torch_device() -> torch.device:
"""Return the torch device for the current rank."""
from sglang.multimodal_gen.runtime.platforms import current_platform
return (
torch.device(f"cuda:{envs.LOCAL_RANK}")
if current_platform.is_cuda_alike()
else torch.device("mps")
)
def _get_unique_name(name: str) -> str:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if name not in _group_name_counter:
_group_name_counter[name] = 0
newname = f"{name}:{_group_name_counter[name]}"
_group_name_counter[name] += 1
return newname
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = ""
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
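    Example (illustrative): for an input like
        {"a": tensor_a, "b": {"c": 1}}
    the result is
        metadata_list = [("a", TensorMetadata(...)), ("b%c", 1)]
        tensor_list = [tensor_a]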
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list = []
for key, value in tensor_dict.items():
assert "%" not in key, (
"Avoid having '%' in key "
"as it is used as a separator for nested entries."
)
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(
prefix + key,
TensorMetadata(device, value.dtype, value.size()),
)
)
tensor_list.append(value)
elif isinstance(value, dict):
if len(value) == 0:
metadata_list.append((prefix + key, value))
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
value, prefix + key + "%"
)
metadata_list.extend(inner_metadata_list)
tensor_list.extend(inner_tensor_list)
else:
metadata_list.append((prefix + key, value))
return metadata_list, tensor_list
def _update_nested_dict(nested_dict, flattened_key, value):
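    """Set `value` at `flattened_key` inside `nested_dict`, re-expanding the
    '%'-separated keys produced by `_split_tensor_dict` into nested dicts."""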
key_splits = flattened_key.split("%")
cur_dict = nested_dict
for k in key_splits[:-1]:
if k not in cur_dict:
cur_dict[k] = {}
cur_dict = cur_dict[k]
cur_dict[key_splits[-1]] = value
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream | None
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to
a specific implementation (e.g. switch allreduce implementation
based on the tensor size and cuda graph mode).
"""
# available attributes:
rank: int # global rank
ranks: List[int] # global ranks in the group
world_size: int # size of the group
# difference between `local_rank` and `rank_in_group`:
# if we have a group of size 4 across two nodes:
# Process | Node | Rank | Local Rank | Rank in Group
# 0 | 0 | 0 | 0 | 0
# 1 | 0 | 1 | 1 | 1
# 2 | 1 | 2 | 0 | 2
# 3 | 1 | 3 | 1 | 3
local_rank: int # local rank in the current node, used to assign devices
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_device_communicator: bool # whether to use device communicator
device_communicator: DeviceCommunicatorBase # device communicator
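    # Illustrative example (not executed): with 4 global ranks and
    # group_ranks=[[0, 1], [2, 3]], ranks 0-1 form one group and ranks 2-3
    # another; each process keeps only the group that contains its own rank,
    # and its `rank_in_group` is its index within that inner list.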
def __init__(
self,
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_device_communicator: bool = True,
use_message_queue_broadcaster: bool = False,
group_name: str | None = None,
):
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self.device_group = device_group
self.cpu_group = cpu_group
assert self.cpu_group is not None, f"{group_ranks=}, {local_rank=}"
assert self.device_group is not None
# TODO: fix it for other platforms
self.device = get_local_torch_device()
from sglang.multimodal_gen.runtime.platforms import current_platform
self.use_device_communicator = use_device_communicator
self.device_communicator: DeviceCommunicatorBase = None # type: ignore
if use_device_communicator and self.world_size > 1:
# Platform-aware device communicator selection
if current_platform.is_cuda_alike():
from sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator import (
CudaCommunicator,
)
self.device_communicator = CudaCommunicator(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
)
else:
# For MPS and CPU, use the CPU communicator
self.device_communicator = CpuCommunicator(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
)
self.mq_broadcaster = None
# TODO(will): check if this is needed
# self.use_custom_op_call = current_platform.is_cuda_alike()
self.use_custom_op_call = False
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
return self.ranks[0]
@property
def last_rank(self):
"""Return the global rank of the last process in the group"""
return self.ranks[-1]
@property
def is_first_rank(self):
"""Return whether the caller is the first process in the group"""
return self.rank == self.first_rank
@property
def is_last_rank(self):
"""Return whether the caller is the last process in the group"""
return self.rank == self.last_rank
@property
def next_rank(self):
"""Return the global rank of the process that follows the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group + 1) % world_size]
@property
def prev_rank(self):
"""Return the global rank of the process that precedes the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(rank_in_group - 1) % world_size]
@property
def group_next_rank(self):
"""Return the group rank of the process that follows the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return (rank_in_group + 1) % world_size
@property
def group_prev_rank(self):
"""Return the group rank of the process that precedes the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return (rank_in_group - 1) % world_size
@property
def skip_rank(self):
"""Return the global rank of the process that skip connects with the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return self.ranks[(world_size - rank_in_group - 1) % world_size]
@property
def group_skip_rank(self):
"""Return the group rank of the process that skip connects with the caller"""
rank_in_group = self.rank_in_group
world_size = self.world_size
return (world_size - rank_in_group - 1) % world_size
@contextmanager
def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
# Platform-aware graph capture
from sglang.multimodal_gen.runtime.platforms import current_platform
if current_platform.is_cuda_alike():
if graph_capture_context is None:
stream = torch.cuda.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream):
yield graph_capture_context
else:
# For non-CUDA platforms (MPS, CPU), just yield the context without stream management
if graph_capture_context is None:
# Create a dummy context for non-CUDA platforms
graph_capture_context = GraphCaptureContext(None)
yield graph_capture_context
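    # Illustrative usage of `graph_capture` (a sketch; `group` and the captured
    # work are placeholders, not part of this module):
    #
    #     with group.graph_capture() as ctx:
    #         # On CUDA-like platforms, work issued here runs on ctx.stream so
    #         # it can be captured without racing the default stream.
    #         out = group.all_reduce(x)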
def all_to_all_4D(
self, input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1
) -> torch.Tensor:
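        """Redistribute a 4-D tensor across the group (a sketch of the intended
        semantics, inferred from its Ulysses-style sequence-parallel usage):
        the local tensor is split into `world_size` chunks along `scatter_dim`,
        and the chunks received from the other ranks are concatenated along
        `gather_dim`. With the defaults, a [b, s_local, h, d] activation is
        exchanged into [b, s, h_local, d]. The actual implementation is
        delegated to the device communicator.
        """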
if self.world_size == 1:
return input_
return self.device_communicator.all_to_all_4D(input_, scatter_dim, gather_dim)
def all_reduce(
self,
input_: torch.Tensor,
op=torch._C._distributed_c10d.ReduceOp.SUM,
async_op: bool = False,
) -> torch.Tensor:
"""
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
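        Illustrative usage (sketch): ``tensor = group.all_reduce(tensor)``;
        always rebind to the returned value instead of relying on in-place
        mutation.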
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
else:
torch.distributed.all_reduce(
input_, op=op, group=self.device_group, async_op=async_op
)
return input_
def all_gather(
self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False
) -> Union[torch.Tensor, List[torch.Tensor]]:
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
input_size = list(input_.size())
input_size[0] *= world_size
output_tensor = torch.empty(
input_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
if dim != 0:
input_size[0] //= world_size
output_tensor = output_tensor.reshape(
[
world_size,
]
+ input_size
)
output_tensor = output_tensor.movedim(0, dim)
if separate_tensors:
tensor_list = [
output_tensor.reshape(-1)
.narrow(0, input_.numel() * i, input_.numel())
.view_as(input_)
for i in range(world_size)
]
return tensor_list
else:
input_size = list(input_.size())
input_size[dim] = input_size[dim] * world_size
# Reshape
output_tensor = output_tensor.reshape(input_size)
return output_tensor
def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(self, input_: torch.Tensor, src: int = 0, async_op: bool = False):
"""Broadcast the input tensor.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(
input_,
src=self.ranks[src],
group=self.device_group,
async_op=async_op,
)
return input_
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
"""Broadcast the input object.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
if self.rank_in_group == src:
torch.distributed.broadcast_object_list(
[obj], src=self.ranks[src], group=self.cpu_group
)
return obj
else:
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=self.ranks[src], group=self.cpu_group
)
return recv[0]
def broadcast_object_list(
self,
obj_list: List[Any],
src: int = 0,
group: Optional[ProcessGroup] = None,
):
"""Broadcast the input object list.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(
obj_list, src=self.ranks[src], group=self.device_group
)
return obj_list
def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""
assert dst < self.world_size, f"Invalid dst rank ({dst})"
assert dst != self.rank, (
"Invalid destination rank. Destination rank is the same "
"as the current rank."
)
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
size_tensor = torch.tensor(
[object_tensor.numel()], dtype=torch.long, device="cpu"
)
# Send object size
torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
# Send object
torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
return None
def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
assert src < self.world_size, f"Invalid src rank ({src})"
assert (
src != self.rank
), "Invalid source rank. Source rank is the same as the current rank."
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
# Receive object size
rank_size = torch.distributed.recv(
size_tensor, src=self.ranks[src], group=self.cpu_group
)
# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu",
)
rank_object = torch.distributed.recv(
object_tensor, src=self.ranks[src], group=self.cpu_group
)
assert (
rank_object == rank_size
), "Received object sender rank does not match the size sender rank."
obj = pickle.loads(object_tensor.numpy().tobytes())
return obj
def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"
        # Keep both the group-local rank (for helpers such as
        # `broadcast_object`) and the global rank (for torch.distributed calls).
        src_in_group = src
        src = self.ranks[src]
        rank = self.rank
        if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src_in_group)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor, src=src, group=metadata_group, async_op=True
)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(
tensor, src=src, group=group, async_op=True
)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else:
            metadata_list = self.broadcast_object(None, src=src_in_group)
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(
value.size, dtype=value.dtype, device=value.device
)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor)
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(
tensor, src=src, group=metadata_group, async_op=True
)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(
tensor, src=src, group=group, async_op=True
)
async_handles.append(handle)
_update_nested_dict(tensor_dict, key, tensor)
else:
_update_nested_dict(tensor_dict, key, value)
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
def send_tensor_dict(
self,
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
        NOTE: `dst` is the group-local rank of the destination rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = self.group_next_rank
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict, dict
), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(
tensor, dst=self.ranks[dst], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None
def recv_tensor_dict(
self, src: Optional[int] = None
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = self.group_prev_rank
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: Dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor)
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(
tensor, src=self.ranks[src], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
_update_nested_dict(tensor_dict, key, tensor)
else:
_update_nested_dict(tensor_dict, key, value)
return tensor_dict
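    # Illustrative pairing of the two methods above (sketch): the sending rank
    # calls `send_tensor_dict({"latent": t})` while its successor calls
    # `recv_tensor_dict()`; with no explicit dst/src, the pair defaults to the
    # next/previous rank in the group.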
def barrier(self):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
terrible because it is internally a broadcast operation with
secretly created GPU tensors. It is easy to mess up the current
device. Use the CPU group instead.
"""
torch.distributed.barrier(group=self.cpu_group)
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the rank_in_group of the destination rank."""
if dst is None:
dst = self.group_next_rank
torch.distributed.send(
tensor,
self.ranks[dst],
group=(
self.device_groups[self.rank_in_group % 2]
if self.world_size == 2
else self.device_group
),
)
def recv(
self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
) -> torch.Tensor:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the rank_in_group of the source rank."""
if src is None:
src = self.group_prev_rank
tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(
tensor,
self.ranks[src],
(
self.device_groups[(self.rank_in_group + 1) % 2]
if self.world_size == 2
else self.device_group
),
)
return tensor
def destroy(self) -> None:
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)
self.device_group = None
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
self.cpu_group = None
if self.device_communicator is not None:
self.device_communicator.destroy()
if self.mq_broadcaster is not None:
self.mq_broadcaster = None
class PipelineGroupCoordinator(GroupCoordinator):
"""
available attributes:
rank: int # global rank
ranks: List[int] # global ranks in the group
world_size: int # size of the group
difference between `local_rank` and `rank_in_group`:
if we have a group of size 4 across two nodes:
Process | Node | Rank | Local Rank | Rank in Group
0 | 0 | 0 | 0 | 0
1 | 0 | 1 | 1 | 1
2 | 1 | 2 | 0 | 2
3 | 1 | 3 | 1 | 3
local_rank: int # local rank used to assign devices
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
"""
def __init__(
self,
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
group_name: str | None = None,
):
super().__init__(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=torch_distributed_backend,
group_name=group_name,
)
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
self.cpu_groups = []
self.device_groups = []
if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1:
for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self.device_group = device_group
self.cpu_group = cpu_group
# when pipeline parallelism is 2, we need to create two groups to avoid
# communication stall.
# *_group_0_1 represents the group for communication from device 0 to
# device 1.
# *_group_1_0 represents the group for communication from device 1 to
# device 0.
elif len(group_ranks[0]) == 2:
for ranks in group_ranks:
device_group_0_1 = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
device_group_1_0 = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo")
cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self.device_groups = [device_group_0_1, device_group_1_0]
self.cpu_groups = [cpu_group_0_1, cpu_group_1_0]
self.device_group = device_group_0_1
self.cpu_group = cpu_group_0_1
assert self.cpu_group is not None
assert self.device_group is not None
self.device = envs.get_device(local_rank)
self.recv_buffer_set: bool = False
self.recv_tasks_queue: List[Tuple[str, int]] = []
self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = []
self.dtype: Optional[torch.dtype] = None
self.num_pipefusion_patches: Optional[int] = None
self.recv_shape: Dict[str, Dict[int, torch.Size]] = {}
self.send_shape: Dict[str, Dict[int, torch.Size]] = {}
        self.recv_buffer: Dict[str, Dict[int, torch.Size]] = {}
        # Buffers registered via `set_extra_tensors_recv_buffer`.
        self.extra_tensors_recv_buffer: Dict[str, List[torch.Tensor]] = {}
self.skip_tensor_recv_buffer_set: bool = False
self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = []
self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = []
self.skip_tensor_recv_buffer: Optional[
Union[List[torch.Tensor], torch.Tensor]
] = None
self.skip_device_group = None
for ranks in group_ranks:
skip_device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend
)
if self.rank in ranks:
self.skip_device_group = skip_device_group
assert self.skip_device_group is not None
def reset_buffer(self):
self.recv_tasks_queue = []
self.receiving_tasks = []
self.recv_shape = {}
self.send_shape = {}
self.recv_buffer = {}
self.recv_skip_tasks_queue = []
self.receiving_skip_tasks = []
self.skip_tensor_recv_buffer = {}
def set_config(self, dtype: torch.dtype):
self.dtype = dtype
def set_recv_buffer(
self,
num_pipefusion_patches: int,
patches_shape_list: List[List[int]],
feature_map_shape: List[int],
dtype: torch.dtype,
):
assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object"
assert (
isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1
), "num_pipefusion_patches must be greater than or equal to 1"
self.dtype = dtype
self.num_pipefusion_patches = num_pipefusion_patches
self.recv_buffer = [
torch.zeros(*shape, dtype=self.dtype, device=self.device)
for shape in patches_shape_list
]
self.recv_buffer.append(
torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)
)
self.recv_buffer_set = True
def set_extra_tensors_recv_buffer(
self,
name: str,
shape: List[int],
num_buffers: int = 1,
dtype: torch.dtype = torch.float16,
):
self.extra_tensors_recv_buffer[name] = [
torch.zeros(*shape, dtype=dtype, device=self.device)
for _ in range(num_buffers)
]
def _check_shape_and_buffer(
self,
tensor_send_to_next=None,
recv_prev=False,
name: Optional[str] = None,
segment_idx: int = 0,
):
send_flag = False
name = name or "latent"
if tensor_send_to_next is not None:
shape_list = self.send_shape.get(name, None)
if shape_list is None:
self.send_shape[name] = {segment_idx: tensor_send_to_next.shape}
send_flag = True
elif shape_list.get(segment_idx, None) is None:
self.send_shape[name][segment_idx] = tensor_send_to_next.shape
send_flag = True
recv_flag = False
if recv_prev:
shape_list = self.recv_shape.get(name, None)
if shape_list is None:
recv_flag = True
elif shape_list.get(segment_idx, None) is None:
recv_flag = True
recv_prev_shape = self._communicate_shapes(
tensor_send_to_next=tensor_send_to_next if send_flag else None,
recv_prev=recv_flag,
)
if recv_flag:
if self.recv_shape.get(name, None) is None:
self.recv_shape[name] = {segment_idx: recv_prev_shape}
else:
self.recv_shape[name][segment_idx] = recv_prev_shape
if self.recv_buffer.get(name, None) is None:
self.recv_buffer[name] = {
segment_idx: torch.zeros(
recv_prev_shape, device=self.device, dtype=self.dtype
)
}
else:
if self.recv_buffer[name].get(segment_idx, None) is not None:
logger.warning(
f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating..."
)
self.recv_buffer[name][segment_idx] = torch.zeros(
recv_prev_shape, device=self.device, dtype=self.dtype
)
def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False):
"""Communicate tensor shapes between stages. Used to communicate
tensor shapes before the actual tensor communication happens.
Args:
            tensor_send_to_next: tensor to send to the next rank (no tensor is
                sent if set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
"""
ops = []
if recv_prev:
recv_prev_dim_tensor = torch.empty(
(1), device=self.device, dtype=torch.int64
)
recv_prev_dim_op = torch.distributed.P2POp(
torch.distributed.irecv,
recv_prev_dim_tensor,
self.prev_rank,
self.device_group,
)
ops.append(recv_prev_dim_op)
if tensor_send_to_next is not None:
send_next_dim_tensor = torch.tensor(
tensor_send_to_next.dim(), device=self.device, dtype=torch.int64
)
send_next_dim_op = torch.distributed.P2POp(
torch.distributed.isend,
send_next_dim_tensor,
self.next_rank,
self.device_group,
)
ops.append(send_next_dim_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
# should take this out once the bug with batch_isend_irecv is resolved.
synchronize()
ops = []
recv_prev_shape_tensor = None
if recv_prev:
recv_prev_shape_tensor = torch.empty(
torch.Size(recv_prev_dim_tensor),
device=self.device,
dtype=torch.int64,
)
recv_prev_shape_op = torch.distributed.P2POp(
torch.distributed.irecv,
recv_prev_shape_tensor,
self.prev_rank,
self.device_group,
)
ops.append(recv_prev_shape_op)
if tensor_send_to_next is not None:
send_next_shape_tensor = torch.tensor(
tensor_send_to_next.size(),
device=self.device,
dtype=torch.int64,
)
send_next_shape_op = torch.distributed.P2POp(
torch.distributed.isend,
send_next_shape_tensor,
self.next_rank,
self.device_group,
)
ops.append(send_next_shape_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
synchronize()
recv_prev_shape = [0, 0, 0]
if recv_prev_shape_tensor is not None:
recv_prev_shape = recv_prev_shape_tensor
return torch.Size(recv_prev_shape)
def pipeline_send(
self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1
) -> None:
tensor = tensor.contiguous()
self._check_shape_and_buffer(
tensor_send_to_next=tensor, name=name, segment_idx=segment_idx
)
self._pipeline_isend(tensor).wait()
def pipeline_isend(
self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1
) -> None:
tensor = tensor.contiguous()
self._check_shape_and_buffer(
tensor_send_to_next=tensor, name=name, segment_idx=segment_idx
)
self._pipeline_isend(tensor)
def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor:
name = name or "latent"
self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
self._pipeline_irecv(self.recv_buffer[name][idx]).wait()
return self.recv_buffer[name][idx]
def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"):
name = name or "latent"
self.recv_tasks_queue.append((name, idx))
def recv_next(self):
if len(self.recv_tasks_queue) == 0:
raise ValueError("No more tasks to receive")
elif len(self.recv_tasks_queue) > 0:
name, idx = self.recv_tasks_queue.pop(0)
self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx)
self.receiving_tasks.append(
(self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx)
)
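    # Illustrative asynchronous receive flow (sketch):
    #   group.add_pipeline_recv_task(idx=0)           # queue what to expect
    #   group.recv_next()                             # post the irecv
    #   latent = group.get_pipeline_recv_data(idx=0)  # wait and read buffer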
def get_pipeline_recv_data(
self, idx: int = -1, name: str = "latent"
) -> torch.Tensor:
assert (
len(self.receiving_tasks) > 0
), "No tasks to receive, call add_pipeline_recv_task first"
receiving_task = self.receiving_tasks.pop(0)
receiving_task[0].wait()
assert (
receiving_task[1] == name and receiving_task[2] == idx
), "Received tensor does not match the requested"
return self.recv_buffer[name][idx]
    def _pipeline_irecv(self, tensor: torch.Tensor):
return torch.distributed.irecv(
tensor,
src=self.prev_rank,
group=(
self.device_groups[(self.rank_in_group + 1) % 2]
if self.world_size == 2
else self.device_group
),
)
    def _pipeline_isend(self, tensor: torch.Tensor):
return torch.distributed.isend(
tensor,
dst=self.next_rank,
group=(
self.device_groups[self.rank_in_group % 2]
if self.world_size == 2
else self.device_group
),
)
def set_skip_tensor_recv_buffer(
self,
patches_shape_list: List[List[int]],
feature_map_shape: List[int],
):
self.skip_tensor_recv_buffer = [
torch.zeros(*shape, dtype=self.dtype, device=self.device)
for shape in patches_shape_list
]
self.skip_tensor_recv_buffer.append(
torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device)
)
self.skip_tensor_recv_buffer_set = True
def pipeline_send_skip(self, tensor: torch.Tensor) -> None:
tensor = tensor.contiguous()
self._pipeline_isend_skip(tensor).wait()
def pipeline_isend_skip(self, tensor: torch.Tensor) -> None:
tensor = tensor.contiguous()
self._pipeline_isend_skip(tensor)
def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor:
self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait()
return self.skip_tensor_recv_buffer[idx]
def add_pipeline_recv_skip_task(self, idx: int = -1):
self.recv_skip_tasks_queue.append(idx)
def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor:
assert (
len(self.receiving_skip_tasks) > 0
), "No tasks to receive, call add_pipeline_recv_skip_task first"
receiving_skip_task = self.receiving_skip_tasks.pop(0)
receiving_skip_task[0].wait()
assert (
receiving_skip_task[2] == idx
), "Received tensor does not match the requested"
return self.skip_tensor_recv_buffer[idx]
def recv_skip_next(self):
if len(self.recv_skip_tasks_queue) == 0:
raise ValueError("No more tasks to receive")
elif len(self.recv_skip_tasks_queue) > 0:
task = self.recv_skip_tasks_queue.pop(0)
idx = task
self.receiving_skip_tasks.append(
(
self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]),
None,
idx,
)
)
    def _pipeline_irecv_skip(self, tensor: torch.Tensor):
return torch.distributed.irecv(
tensor, src=self.skip_rank, group=self.skip_device_group
)
    def _pipeline_isend_skip(self, tensor: torch.Tensor):
return torch.distributed.isend(
tensor, dst=self.skip_rank, group=self.skip_device_group
)
class SequenceParallelGroupCoordinator(GroupCoordinator):
def __init__(
self,
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
group_name: str | None = None,
**kwargs,
):
super().__init__(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=torch_distributed_backend,
group_name=group_name,
)
ulysses_group = kwargs.get("ulysses_group", None)
ring_group = kwargs.get("ring_group", None)
if ulysses_group is None:
raise RuntimeError(
f"Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator"
)
if ring_group is None:
raise RuntimeError(
f"Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator"
)
self.ulysses_group = ulysses_group
self.ring_group = ring_group
self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group)
self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group)
self.ring_world_size = torch.distributed.get_world_size(self.ring_group)
self.ring_rank = torch.distributed.get_rank(self.ring_group)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright 2024 xDiT team.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""sgl-diffusion distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
- any code dealing with the distributed stuff
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.
If you only need to use the distributed environment without model parallelism,
you can skip the model parallel initialization and destruction steps.
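Illustrative workflow (a sketch; the degrees below are assumptions, not
defaults, and must multiply out to the world size):
    init_distributed_environment(world_size=4, rank=rank, local_rank=local_rank)
    initialize_model_parallel(
        sequence_parallel_degree=2,
        ulysses_degree=2,
        ring_degree=1,
        pipeline_parallel_degree=2,
    )
    ...  # distributed code
    destroy_model_parallel()
    destroy_distributed_environment()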
"""
import contextlib
import os
import weakref
from collections import namedtuple
from collections.abc import Callable
from contextlib import contextmanager
from multiprocessing import shared_memory
from typing import Any, List, Optional
from unittest.mock import patch
import torch
import torch.distributed
from torch.distributed import ProcessGroup
import sglang.multimodal_gen.envs as envs
from sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from ..utils.distributed import RankGenerator
from .group_coordinator import (
GroupCoordinator,
PipelineGroupCoordinator,
SequenceParallelGroupCoordinator,
get_local_torch_device,
)
logger = init_logger(__name__)
_WORLD: Optional[GroupCoordinator] = None
_TP: Optional[GroupCoordinator] = None
_SP: Optional[SequenceParallelGroupCoordinator] = None
_PP: Optional[PipelineGroupCoordinator] = None
_CFG: Optional[GroupCoordinator] = None
_DP: Optional[GroupCoordinator] = None
_DIT: Optional[GroupCoordinator] = None
_VAE: Optional[GroupCoordinator] = None
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: dict[str, torch.Tensor | Any]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list: list[tuple[str, Any]] = []
tensor_list: list[torch.Tensor] = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size()))
)
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
def _register_group(group: "GroupCoordinator") -> None:
_groups[group.unique_name] = weakref.ref(group)
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor)
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
_WORLD: GroupCoordinator | None = None
_NODE: GroupCoordinator | None = None
def get_world_group() -> GroupCoordinator:
assert _WORLD is not None, "world group is not initialized"
return _WORLD
def init_world_group(
ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=True,
group_name="world",
)
# xDiT
def init_parallel_group_coordinator(
group_ranks: List[List[int]],
local_rank: int,
backend: str,
parallel_mode: str,
**kwargs,
) -> GroupCoordinator:
"""
Returns a Group Coordinator for the given parallel mode
"""
assert parallel_mode in [
"data",
"pipeline",
"tensor",
"sequence",
"classifier_free_guidance",
], f"parallel_mode {parallel_mode} is not supported"
if parallel_mode == "pipeline":
return PipelineGroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
group_name="pp_group",
)
elif parallel_mode == "sequence":
return SequenceParallelGroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
group_name="sp_group",
**kwargs,
)
else:
# fallback to GroupCoordinator
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
group_name="cfg_group",
)
# def init_parallel_group_coordinator(
# group_ranks: list[list[int]],
# local_rank: int,
# backend: str,
# use_message_queue_broadcaster: bool = False,
# group_name: str | None = None,
# ) -> GroupCoordinator:
# return GroupCoordinator(
# group_ranks=group_ranks,
# local_rank=local_rank,
# torch_distributed_backend=backend,
# use_device_communicator=True,
# use_message_queue_broadcaster=use_message_queue_broadcaster,
# group_name=group_name,
# )
_TP: GroupCoordinator | None = None
def get_tp_group() -> GroupCoordinator:
assert _TP is not None, "tensor model parallel group is not initialized"
return _TP
_ENABLE_CUSTOM_ALL_REDUCE = True
def set_custom_all_reduce(enable: bool):
global _ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment(
world_size: int = 1,
rank: int = 0,
distributed_init_method: str = "env://",
local_rank: int = 0,
backend: str = "nccl",
device_id: torch.device | None = None,
):
# Determine the appropriate backend based on the platform
from sglang.multimodal_gen.runtime.platforms import current_platform
if backend == "nccl" and not current_platform.is_cuda_alike():
# Use gloo backend for non-CUDA platforms (MPS, CPU)
backend = "gloo"
logger.info("Using gloo backend for %s platform", current_platform.device_name)
logger.debug(
"world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
world_size,
rank,
local_rank,
distributed_init_method,
backend,
)
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
"distributed environment"
)
# For MPS, don't pass device_id as it doesn't support device indices
extra_args = {} if current_platform.is_mps() else dict(device_id=device_id)
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
**extra_args,
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
if local_rank == -1:
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
else:
assert (
_WORLD.world_size == torch.distributed.get_world_size()
), "world group already initialized with a different world size"
_SP: GroupCoordinator | None = None
def get_sp_group() -> SequenceParallelGroupCoordinator:
    assert _SP is not None, "sequence parallel group is not initialized"
return _SP
_DP: GroupCoordinator | None = None
def get_dp_group() -> GroupCoordinator:
assert _DP is not None, "data parallel group is not initialized"
return _DP
# xDiT
def initialize_model_parallel(
data_parallel_size: int = 1,
classifier_free_guidance_degree: int = 1,
sequence_parallel_degree: Optional[int] = None,
ulysses_degree: int = 1,
ring_degree: int = 1,
tensor_parallel_degree: int = 1,
pipeline_parallel_degree: int = 1,
vae_parallel_size: int = 0,
backend: Optional[str] = None,
) -> None:
"""
Initialize model parallel groups.
Arguments:
data_parallel_size: number of data parallelism groups.
classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG)
sequence_parallel_degree: number of GPUs used for sequence parallelism. sequence_parallel_degree = ulysses_degree * ring_degree
ulysses_degree: number of GPUs used for ulysses sequence parallelism.
ring_degree: number of GPUs used for ring sequence parallelism.
tensor_parallel_degree: number of GPUs used for tensor parallelism.
        pipeline_parallel_degree: number of GPUs used for pipeline parallelism.
        vae_parallel_size: number of GPUs reserved for the VAE parallel group
            (0 means no separate VAE group is created).
        backend: distributed backend of the PyTorch collective communication.
    Let's say we have a total of 16 GPUs denoted by g0 ... g15, and we use
    2 groups to parallelize the batch dim (dp), 2 groups to parallelize the
    batch split caused by CFG, 2 GPUs to parallelize the sequence, and
    2 GPUs to parallelize the pipeline:
    dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16.
    The present function will create 8 data-parallel groups,
    8 CFG-parallel groups, 8 pipeline-parallel groups, and
    8 sequence-parallel groups:
8 data-parallel groups:
[g0, g8], [g1, g9], [g2, g10], [g3, g11],
[g4, g12], [g5, g13], [g6, g14], [g7, g15]
8 CFG-parallel groups:
[g0, g4], [g1, g5], [g2, g6], [g3, g7],
[g8, g12], [g9, g13], [g10, g14], [g11, g15]
8 sequence-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7],
[g8, g9], [g10, g11], [g12, g13], [g14, g15]
8 pipeline-parallel groups:
[g0, g2], [g4, g6], [g8, g10], [g12, g14],
[g1, g3], [g5, g7], [g9, g11], [g13, g15]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
    if backend is None:
        backend = envs.get_torch_distributed_backend()
    if sequence_parallel_degree is None:
        # Default per the documented invariant:
        # sequence_parallel_degree = ulysses_degree * ring_degree.
        sequence_parallel_degree = ulysses_degree * ring_degree
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
dit_parallel_size = (
data_parallel_size
* classifier_free_guidance_degree
* sequence_parallel_degree
* pipeline_parallel_degree
* tensor_parallel_degree
)
if world_size < dit_parallel_size:
raise RuntimeError(
f"world_size ({world_size}) is less than "
f"tensor_parallel_degree ({tensor_parallel_degree}) x "
f"pipeline_parallel_degree ({pipeline_parallel_degree}) x"
f"sequence_parallel_degree ({sequence_parallel_degree}) x"
f"classifier_free_guidance_degree "
f"({classifier_free_guidance_degree}) x"
f"data_parallel_degree ({data_parallel_size})"
)
rank_generator: RankGenerator = RankGenerator(
tensor_parallel_degree,
sequence_parallel_degree,
pipeline_parallel_degree,
classifier_free_guidance_degree,
data_parallel_size,
"tp-sp-pp-cfg-dp",
)
global _DP
assert _DP is None, "data parallel group is already initialized"
_DP = init_parallel_group_coordinator(
group_ranks=rank_generator.get_ranks("dp"),
local_rank=get_world_group().local_rank,
backend=backend,
parallel_mode="data",
)
global _CFG
assert _CFG is None, "classifier_free_guidance group is already initialized"
_CFG = init_parallel_group_coordinator(
group_ranks=rank_generator.get_ranks("cfg"),
local_rank=get_world_group().local_rank,
backend=backend,
parallel_mode="classifier_free_guidance",
)
global _PP
assert _PP is None, "pipeline model parallel group is already initialized"
_PP = init_parallel_group_coordinator(
group_ranks=rank_generator.get_ranks("pp"),
local_rank=get_world_group().local_rank,
backend=backend,
parallel_mode="pipeline",
)
global _SP
assert _SP is None, "sequence parallel group is already initialized"
from yunchang import set_seq_parallel_pg
from yunchang.globals import PROCESS_GROUP
set_seq_parallel_pg(
sp_ulysses_degree=ulysses_degree,
sp_ring_degree=ring_degree,
rank=get_world_group().rank_in_group,
world_size=dit_parallel_size,
)
_SP = init_parallel_group_coordinator(
group_ranks=rank_generator.get_ranks("sp"),
local_rank=get_world_group().local_rank,
backend=backend,
parallel_mode="sequence",
ulysses_group=PROCESS_GROUP.ULYSSES_PG,
ring_group=PROCESS_GROUP.RING_PG,
)
global _TP
assert _TP is None, "Tensor parallel group is already initialized"
_TP = init_parallel_group_coordinator(
group_ranks=rank_generator.get_ranks("tp"),
local_rank=get_world_group().local_rank,
backend=backend,
parallel_mode="tensor",
)
if vae_parallel_size > 0:
init_vae_group(dit_parallel_size, vae_parallel_size, backend)
init_dit_group(dit_parallel_size, backend)
#
# def initialize_model_parallel(
# tensor_model_parallel_size: int = 1,
# sequence_model_parallel_size: int = 1,
# data_parallel_size: int = 1,
# backend: str | None = None,
# ) -> None:
# """
# Initialize model parallel groups.
#
# Arguments:
# tensor_model_parallel_size: number of GPUs used for tensor model
# parallelism (used for language encoder).
# sequence_model_parallel_size: number of GPUs used for sequence model
# parallelism (used for DiT).
# """
# # Get world size and rank. Ensure some consistencies.
# assert (
# _WORLD is not None
# ), "world group is not initialized, please call init_distributed_environment first"
# world_size: int = get_world_size()
# backend = backend or torch.distributed.get_backend(get_world_group().device_group)
# assert (
# world_size >= tensor_model_parallel_size
# ), f"world_size({world_size}) must be greater than or equal to tensor_model_parallel_size({tensor_model_parallel_size})"
# num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
# global _TP
# assert _TP is None, "tensor model parallel group is already initialized"
# group_ranks = []
# for i in range(num_tensor_model_parallel_groups):
# ranks = list(
# range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
# )
# group_ranks.append(ranks)
#
# # message queue broadcaster is only used in tensor model parallel group
# _TP = init_parallel_group_coordinator(
# group_ranks,
# get_world_group().local_rank,
# backend,
# use_message_queue_broadcaster=True,
# group_name="tp",
# )
#
# # Build the sequence model-parallel groups.
# num_sequence_model_parallel_groups: int = world_size // sequence_model_parallel_size
# global _SP
# assert _SP is None, "sequence model parallel group is already initialized"
# group_ranks = []
#
# # Since SP is incompatible with TP and PP, we can use a simpler group creation logic
# for i in range(num_sequence_model_parallel_groups):
# # Create groups of consecutive ranks
# ranks = list(
# range(
# i * sequence_model_parallel_size, (i + 1) * sequence_model_parallel_size
# )
# )
# group_ranks.append(ranks)
#
# _SP = init_parallel_group_coordinator(
# group_ranks, get_world_group().local_rank, backend, group_name="sp"
# )
#
# # Build the data parallel groups.
# num_data_parallel_groups: int = sequence_model_parallel_size
# global _DP
# assert _DP is None, "data parallel group is already initialized"
# group_ranks = []
#
# for i in range(num_data_parallel_groups):
# ranks = list(range(i, world_size, num_data_parallel_groups))
# group_ranks.append(ranks)
#
# _DP = init_parallel_group_coordinator(
# group_ranks, get_world_group().local_rank, backend, group_name="dp"
# )
#
def get_sp_world_size() -> int:
"""Return world size for the sequence model parallel group."""
return get_sp_group().world_size
def get_sp_parallel_rank() -> int:
"""Return my rank for the sequence model parallel group."""
return get_sp_group().rank_in_group
def get_world_size() -> int:
"""Return world size for the world group."""
return get_world_group().world_size
def get_world_rank() -> int:
"""Return my rank for the world group."""
return get_world_group().rank
def get_dp_world_size() -> int:
"""Return world size for the data parallel group."""
return get_dp_group().world_size
def get_dp_rank() -> int:
"""Return my rank for the data parallel group."""
return get_dp_group().rank_in_group
def maybe_init_distributed_environment_and_model_parallel(
tp_size: int,
sp_size: int,
enable_cfg_parallel: bool,
ulysses_degree: int = 1,
ring_degree: int = 1,
dp_size: int = 1,
distributed_init_method: str = "env://",
):
from sglang.multimodal_gen.runtime.platforms import current_platform
if _WORLD is not None and model_parallel_is_initialized():
# make sure the tp and sp sizes are correct
assert (
get_tp_world_size() == tp_size
), f"You are trying to initialize model parallel groups with size {tp_size}, but they are already initialized with size {get_tp_world_size()}"
assert (
get_sp_world_size() == sp_size
), f"You are trying to initialize model parallel groups with size {sp_size}, but they are already initialized with size {get_sp_world_size()}"
return
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
rank = int(os.environ.get("RANK", 0))
device = get_local_torch_device()
logger.info(
"Initializing distributed environment with world_size=%d, device=%s",
world_size,
device,
main_process_only=False,
)
init_distributed_environment(
world_size=world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
device_id=device,
)
initialize_model_parallel(
data_parallel_size=dp_size,
classifier_free_guidance_degree=2 if enable_cfg_parallel else 1,
tensor_parallel_degree=tp_size,
ulysses_degree=ulysses_degree,
ring_degree=ring_degree,
sequence_parallel_degree=sp_size,
)
# Only set CUDA device if we're on a CUDA platform
if current_platform.is_cuda_alike():
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
def model_parallel_is_initialized() -> bool:
"""Check if tensor, sequence parallel groups are initialized."""
return _TP is not None and _SP is not None and _DP is not None and _CFG is not None
_TP_STATE_PATCHED = False
@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _TP_STATE_PATCHED
assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
_TP_STATE_PATCHED = True
old_tp_group = get_tp_group()
global _TP
_TP = tp_group
try:
yield
finally:
# restore the original state
_TP_STATE_PATCHED = False
_TP = old_tp_group
def get_tp_world_size() -> int:
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
def get_tp_rank() -> int:
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group
def destroy_distributed_environment() -> None:
global _WORLD
if _WORLD:
_WORLD.destroy()
_WORLD = None
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
if shutdown_ray:
import ray # Lazy import Ray
ray.shutdown()
def is_the_same_node_as(
pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
) -> list[int]:
"""
    This is a collective operation that returns whether each rank is on the
    same node as the source rank. It tests whether the processes are attached
    to the same memory system (shared access to shared memory).
"""
if isinstance(pg, ProcessGroup):
assert (
torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL
), "in_the_same_node_as should be tested with a non-NCCL group."
# local rank inside the group
rank = torch.distributed.get_rank(group=pg)
world_size = torch.distributed.get_world_size(group=pg)
# global ranks of the processes in the group
ranks = torch.distributed.get_process_group_ranks(pg)
else:
rank = pg.rank
world_size = pg.world_size
ranks = list(range(world_size))
# local tensor in each process to store the result
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
magic_message = b"magic_message"
shm = None
try:
with contextlib.suppress(OSError):
if rank == source_rank:
# create a shared memory segment
shm = shared_memory.SharedMemory(create=True, size=128)
shm.buf[: len(magic_message)] = magic_message
if isinstance(pg, ProcessGroup):
torch.distributed.broadcast_object_list(
[shm.name], src=ranks[source_rank], group=pg
)
else:
pg.broadcast_obj(shm.name, src=source_rank)
is_in_the_same_node[rank] = 1
else:
# try to open the shared memory segment
if isinstance(pg, ProcessGroup):
recv = [None]
torch.distributed.broadcast_object_list(
recv, src=ranks[source_rank], group=pg
)
name = recv[0]
else:
name = pg.broadcast_obj(None, src=source_rank)
                # Workaround for https://stackoverflow.com/q/62748654/9191338:
                # Python's resource tracker incorrectly tracks shared memory even
                # if it was not created by this process; patching it avoids that.
with patch(
"multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None,
):
shm = shared_memory.SharedMemory(name=name)
if shm.buf[: len(magic_message)] == magic_message:
is_in_the_same_node[rank] = 1
except Exception as e:
logger.error("Error ignored in is_in_the_same_node: %s", e)
finally:
if shm:
shm.close()
if isinstance(pg, ProcessGroup):
torch.distributed.barrier(group=pg)
else:
pg.barrier()
# clean up the shared memory segment
with contextlib.suppress(OSError):
if rank == source_rank and shm:
shm.unlink()
if isinstance(pg, ProcessGroup):
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
aggregated_data = is_in_the_same_node
else:
aggregated_data = torch.zeros_like(is_in_the_same_node)
for i in range(world_size):
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
aggregated_data += rank_data
return [x == 1 for x in aggregated_data.tolist()]
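# Illustrative sketch (assumption, not from the original file): node locality is
# typically probed over a CPU-backed (gloo) group, since the assert above
# rejects NCCL groups:
#
#     cpu_group = torch.distributed.new_group(backend="gloo")
#     same_node = is_the_same_node_as(cpu_group, source_rank=0)
#     # same_node[i] is True iff rank i shares a node with rank 0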
def initialize_tensor_parallel_group(
tensor_model_parallel_size: int = 1,
backend: str | None = None,
group_name_suffix: str = "",
) -> GroupCoordinator:
"""Initialize a tensor parallel group for a specific model.
This function creates a tensor parallel group that can be used with the
patch_tensor_parallel_group context manager. It allows different models
to use different tensor parallelism configurations.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
backend: communication backend to use.
group_name_suffix: optional suffix to make the group name unique.
Returns:
A GroupCoordinator for tensor parallelism that can be used with
the patch_tensor_parallel_group context manager.
Example usage:
```python
# Initialize tensor parallel group for model1
tp_group_model1 = initialize_tensor_parallel_group(
tensor_model_parallel_size=4,
group_name_suffix="model1"
)
# Use tensor parallelism for model1
with patch_tensor_parallel_group(tp_group_model1):
# Run model1 with tensor parallelism
output1 = model1(input1)
```
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
# Ensure the world size is compatible with the parallelism configuration
assert (
world_size % tensor_model_parallel_size == 0
), f"World size ({world_size}) must be divisible by tensor_model_parallel_size ({tensor_model_parallel_size})"
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
tp_group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(
range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
)
tp_group_ranks.append(ranks)
# Create TP group coordinator with a unique name
group_name = f"tp_{group_name_suffix}" if group_name_suffix else "tp"
tp_group = init_parallel_group_coordinator(
tp_group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name=group_name,
)
return tp_group
def initialize_sequence_parallel_group(
sequence_model_parallel_size: int = 1,
backend: str | None = None,
group_name_suffix: str = "",
) -> GroupCoordinator:
"""Initialize a sequence parallel group for a specific model.
This function creates a sequence parallel group that can be used with the
patch_sequence_parallel_group context manager. It allows different models
to use different sequence parallelism configurations.
Arguments:
sequence_model_parallel_size: number of GPUs used for sequence model parallelism.
backend: communication backend to use.
group_name_suffix: optional suffix to make the group name unique.
Returns:
A GroupCoordinator for sequence parallelism that can be used with
the patch_sequence_parallel_group context manager.
Example usage:
```python
# Initialize sequence parallel group for model2
sp_group_model2 = initialize_sequence_parallel_group(
sequence_model_parallel_size=2,
group_name_suffix="model2"
)
# Use sequence parallelism for model2
with patch_sequence_parallel_group(sp_group_model2):
# Run model2 with sequence parallelism
output2 = model2(input2)
```
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
# Ensure the world size is compatible with the parallelism configuration
assert (
world_size % sequence_model_parallel_size == 0
), f"World size ({world_size}) must be divisible by sequence_model_parallel_size ({sequence_model_parallel_size})"
# Build the sequence model-parallel groups.
num_sequence_model_parallel_groups: int = world_size // sequence_model_parallel_size
sp_group_ranks = []
for i in range(num_sequence_model_parallel_groups):
# Create groups of consecutive ranks
ranks = list(
range(
i * sequence_model_parallel_size, (i + 1) * sequence_model_parallel_size
)
)
sp_group_ranks.append(ranks)
# Create SP group coordinator with a unique name
group_name = f"sp_{group_name_suffix}" if group_name_suffix else "sp"
sp_group = init_parallel_group_coordinator(
sp_group_ranks, get_world_group().local_rank, backend, group_name=group_name
)
return sp_group
# * QUERY
def get_world_group() -> GroupCoordinator:
assert _WORLD is not None, "world group is not initialized"
return _WORLD
# TP
def get_tp_group() -> GroupCoordinator:
assert _TP is not None, "tensor model parallel group is not initialized"
return _TP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
return get_tp_group().rank_in_group
def get_sequence_parallel_world_size():
"""Return world size for the sequence parallel group."""
return get_sp_group().world_size
def get_sequence_parallel_rank():
"""Return my rank for the sequence parallel group."""
return get_sp_group().rank_in_group
def get_ulysses_parallel_world_size():
return get_sp_group().ulysses_world_size
def get_ulysses_parallel_rank():
return get_sp_group().ulysses_rank
def get_ring_parallel_world_size():
return get_sp_group().ring_world_size
def get_ring_parallel_rank():
return get_sp_group().ring_rank
# PP
def get_pp_group() -> PipelineGroupCoordinator:
assert _PP is not None, "pipeline model parallel group is not initialized"
return _PP
def get_pipeline_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
return get_pp_group().world_size
def get_pipeline_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
return get_pp_group().rank_in_group
def is_pipeline_first_stage():
"""Return True if in the first pipeline model parallel stage, False otherwise."""
return get_pipeline_parallel_rank() == 0
def is_pipeline_last_stage():
"""Return True if in the last pipeline model parallel stage, False otherwise."""
return get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)
# CFG
def get_cfg_group() -> GroupCoordinator:
assert (
_CFG is not None
), "classifier_free_guidance parallel group is not initialized"
return _CFG
def get_classifier_free_guidance_world_size():
"""Return world size for the classifier_free_guidance parallel group."""
return get_cfg_group().world_size
def get_classifier_free_guidance_rank():
"""Return my rank for the classifier_free_guidance parallel group."""
return get_cfg_group().rank_in_group
# DP
def get_dp_group() -> GroupCoordinator:
    assert _DP is not None, "data parallel group is not initialized"
return _DP
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return get_dp_group().world_size
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return get_dp_group().rank_in_group
def is_dp_last_group():
"""Return True if in the last data parallel group, False otherwise."""
return (
get_sequence_parallel_rank() == (get_sequence_parallel_world_size() - 1)
and get_classifier_free_guidance_rank()
== (get_classifier_free_guidance_world_size() - 1)
and get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)
)
def get_dit_world_size():
"""Return world size for the DiT model (excluding VAE)."""
return (
get_data_parallel_world_size()
* get_classifier_free_guidance_world_size()
* get_sequence_parallel_world_size()
* get_pipeline_parallel_world_size()
* get_tensor_model_parallel_world_size()
)
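# Example of the product above (hypothetical degrees, for illustration only):
# dp=2, cfg=2, sp=2, pp=1, tp=1 gives a DiT world size of 2 * 2 * 2 * 1 * 1 = 8 ranks.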
# Add VAE getter functions
def get_vae_parallel_group() -> GroupCoordinator:
assert _VAE is not None, "VAE parallel group is not initialized"
return _VAE
def get_vae_parallel_world_size():
"""Return world size for the VAE parallel group."""
return get_vae_parallel_group().world_size
def get_vae_parallel_rank():
"""Return my rank for the VAE parallel group."""
return get_vae_parallel_group().rank_in_group
# * SET
def init_world_group(
ranks: List[int], local_rank: int, backend: str
) -> GroupCoordinator:
return GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
)
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (
_DP is not None
and _CFG is not None
and _SP is not None
and _PP is not None
and _TP is not None
)
def init_dit_group(
dit_parallel_size: int,
backend: str,
):
global _DIT
_DIT = torch.distributed.new_group(
ranks=list(range(dit_parallel_size)), backend=backend
)
def get_dit_group():
assert _DIT is not None, "DIT group is not initialized"
return _DIT
def init_vae_group(
dit_parallel_size: int,
vae_parallel_size: int,
backend: str,
):
# Initialize VAE group first
global _VAE
assert _VAE is None, "VAE parallel group is already initialized"
vae_ranks = list(range(dit_parallel_size, dit_parallel_size + vae_parallel_size))
_VAE = torch.distributed.new_group(ranks=vae_ranks, backend=backend)
def destroy_model_parallel() -> None:
"""Set the groups to none and destroy them."""
global _TP
if _TP:
_TP.destroy()
_TP = None
global _SP
if _SP:
_SP.destroy()
_SP = None
global _DP
if _DP:
_DP.destroy()
_DP = None
# xDit
# def destroy_model_parallel():
# """Set the groups to none and destroy them."""
# global _DP
# if _DP:
# _DP.destroy()
# _DP = None
#
# global _CFG
# if _CFG:
# _CFG.destroy()
# _CFG = None
#
# global _SP
# if _SP:
# _SP.destroy()
# _SP = None
#
# global _TP
# if _TP:
# _TP.destroy()
# _TP = None
#
# global _PP
# if _PP:
# _PP.destroy()
# _PP = None
#
# global _VAE
# if _VAE:
# _VAE.destroy()
# _VAE = None
def destroy_distributed_environment():
global _WORLD
if _WORLD:
_WORLD.destroy()
_WORLD = None
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/utils.py
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import pickle
import time
from collections import deque
from collections.abc import Sequence
from typing import Any
import torch
from torch.distributed import TCPStore
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator) -> None:
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tuple(tensor_list)
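# Illustrative sketch (not from the original file): splitting a (4, 8) tensor
# into two partitions along the last dimension yields two (4, 4) chunks.
#
#     x = torch.arange(32, dtype=torch.float32).reshape(4, 8)
#     a, b = split_tensor_along_last_dim(x, 2, contiguous_split_chunks=True)
#     assert a.shape == (4, 4) and b.shape == (4, 4)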
@dataclasses.dataclass
class StatelessProcessGroup:
"""A dataclass to hold a metadata store, and the rank, world_size of the
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter
recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0
broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
# A deque to store the data entries, with key and timestamp.
entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)
def __post_init__(self):
assert self.rank < self.world_size
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.perf_counter()))
def expire_data(self) -> None:
"""Expire data that is older than `data_expiration_seconds` seconds."""
while self.entries:
# check the oldest entry
key, timestamp = self.entries[0]
if time.perf_counter() - timestamp > self.data_expiration_seconds:
self.store.delete_key(key)
self.entries.popleft()
else:
break
def recv_obj(self, src: int) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
)
self.recv_src_counter[src] += 1
return obj
def broadcast_obj(self, obj: Any | None, src: int) -> Any:
"""Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object, so use it
        sparingly (e.g., during initialization).
"""
if self.rank == src:
self.expire_data()
key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.perf_counter()))
return obj
else:
key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return recv_obj
def all_gather_obj(self, obj: Any) -> list[Any]:
"""All gather an object from all ranks."""
gathered_objs = []
for i in range(self.world_size):
if i == self.rank:
gathered_objs.append(obj)
self.broadcast_obj(obj, src=self.rank)
else:
recv_obj = self.broadcast_obj(None, src=i)
gathered_objs.append(recv_obj)
return gathered_objs
def barrier(self):
"""A barrier to synchronize all ranks."""
for i in range(self.world_size):
if i == self.rank:
self.broadcast_obj(None, src=self.rank)
else:
self.broadcast_obj(None, src=i)
@staticmethod
def create(
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
If we have process A and process B called `torch.distributed.init_process_group`
to form a group, and then we want to form another group with process A, B, C,
D, it is not possible in PyTorch, because process A and process B have already
formed a group, and process C and process D cannot join that group. This
function is a workaround for this issue.
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It will return a `StatelessProcessGroup` object that can be
used for exchanging metadata. With this function, process A and process B
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
data_expiration_seconds=data_expiration_seconds,
)
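# Illustrative sketch (host/port values are placeholders, not from the original
# file): each participating process calls `create` with the same host/port and
# its own rank, then exchanges small metadata objects.
#
#     pg = StatelessProcessGroup.create(
#         host="127.0.0.1", port=29555, rank=rank, world_size=world_size
#     )
#     infos = pg.all_gather_obj({"rank": pg.rank})
#     pg.barrier()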
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/types.py
import argparse
from sglang.multimodal_gen.utils import FlexibleArgumentParser
class CLISubcommand:
"""Base class for CLI subcommands"""
name: str
def cmd(self, args: argparse.Namespace) -> None:
"""Execute the command with the given arguments"""
raise NotImplementedError
def validate(self, args: argparse.Namespace) -> None:
"""Validate the arguments for this command"""
pass
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
"""Initialize the subparser for this command"""
raise NotImplementedError
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
import argparse
import dataclasses
import os
from typing import cast
from sglang.multimodal_gen import DiffGenerator
from sglang.multimodal_gen.configs.sample.base import (
SamplingParams,
generate_request_id,
)
from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand
from sglang.multimodal_gen.runtime.entrypoints.cli.utils import (
RaiseNotImplementedAction,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import FlexibleArgumentParser
logger = init_logger(__name__)
def add_multimodal_gen_generate_args(parser: argparse.ArgumentParser):
"""Add the arguments for the generate command."""
parser.add_argument(
"--config",
type=str,
default="",
required=False,
help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional.",
)
parser = ServerArgs.add_cli_args(parser)
parser = SamplingParams.add_cli_args(parser)
parser.add_argument(
"--text-encoder-configs",
action=RaiseNotImplementedAction,
help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
)
return parser
def generate_cmd(args: argparse.Namespace):
"""The entry point for the generate command."""
# FIXME(mick): do not hard code
args.request_id = generate_request_id()
server_args = ServerArgs.from_cli_args(args)
sampling_params = SamplingParams.from_cli_args(args)
sampling_params.request_id = generate_request_id()
generator = DiffGenerator.from_pretrained(
model_path=server_args.model_path, server_args=server_args
)
generator.generate(prompt=sampling_params.prompt, sampling_params=sampling_params)
class GenerateSubcommand(CLISubcommand):
"""The `generate` subcommand for the sgl-diffusion CLI"""
def __init__(self) -> None:
self.name = "generate"
super().__init__()
self.init_arg_names = self._get_init_arg_names()
self.generation_arg_names = self._get_generation_arg_names()
def _get_init_arg_names(self) -> list[str]:
"""Get names of arguments for DiffGenerator initialization"""
return ["num_gpus", "tp_size", "sp_size", "model_path"]
def _get_generation_arg_names(self) -> list[str]:
"""Get names of arguments for generate_video method"""
return [field.name for field in dataclasses.fields(SamplingParams)]
def cmd(self, args: argparse.Namespace) -> None:
generate_cmd(args)
def validate(self, args: argparse.Namespace) -> None:
"""Validate the arguments for this command"""
if args.num_gpus is not None and args.num_gpus <= 0:
raise ValueError("Number of gpus must be positive")
if args.config and not os.path.exists(args.config):
raise ValueError(f"Config file not found: {args.config}")
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
generate_parser = subparsers.add_parser(
"generate",
help="Run inference on a model",
usage="sgl_diffusion generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]",
)
generate_parser = add_multimodal_gen_generate_args(generate_parser)
return cast(FlexibleArgumentParser, generate_parser)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/main.py
from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand
from sglang.multimodal_gen.runtime.entrypoints.cli.generate import GenerateSubcommand
from sglang.multimodal_gen.runtime.entrypoints.cli.serve import ServeSubcommand
from sglang.multimodal_gen.utils import FlexibleArgumentParser
def generate_cmd_init() -> list[CLISubcommand]:
return [GenerateSubcommand(), ServeSubcommand()]
def cmd_init() -> list[CLISubcommand]:
"""Initialize all commands from separate modules"""
commands = []
commands.extend(generate_cmd_init())
return commands
def main() -> None:
parser = FlexibleArgumentParser(description="sgl-diffusion CLI")
parser.add_argument("-v", "--version", action="version", version="0.1.0")
subparsers = parser.add_subparsers(required=False, dest="subparser")
cmds = {}
for cmd in cmd_init():
cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
cmds[cmd.name] = cmd
args = parser.parse_args()
if args.subparser in cmds:
cmds[args.subparser].validate(args)
if hasattr(args, "dispatch_function"):
args.dispatch_function(args)
else:
parser.print_help()
if __name__ == "__main__":
main()
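# Illustrative CLI usage sketch (model path and prompt are placeholders),
# matching the subcommands registered via cmd_init():
#
#     sgl_diffusion generate --model-path <MODEL_PATH_OR_ID> --prompt "a cat in a spacesuit"
#     sgl_diffusion serve --model-path <MODEL_PATH_OR_ID>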
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
from typing import cast
from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand
from sglang.multimodal_gen.runtime.launch_server import launch_server
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import FlexibleArgumentParser
logger = init_logger(__name__)
def add_multimodal_gen_serve_args(parser: argparse.ArgumentParser):
"""Add the arguments for the serve command."""
parser.add_argument(
"--config",
type=str,
default="",
required=False,
help="Read CLI options from a config JSON or YAML file.",
)
return ServerArgs.add_cli_args(parser)
def execute_serve_cmd(args: argparse.Namespace, unknown_args: list[str] | None = None):
"""The entry point for the serve command."""
server_args = ServerArgs.from_cli_args(args, unknown_args)
server_args.post_init_serve()
launch_server(server_args)
class ServeSubcommand(CLISubcommand):
"""The `serve` subcommand for the sgl-diffusion CLI"""
def __init__(self) -> None:
self.name = "serve"
super().__init__()
def cmd(
self, args: argparse.Namespace, unknown_args: list[str] | None = None
) -> None:
execute_serve_cmd(args, unknown_args)
def validate(self, args: argparse.Namespace) -> None:
"""Validate the arguments for this command"""
if args.config and not os.path.exists(args.config):
raise ValueError(f"Config file not found: {args.config}")
def subparser_init(
self, subparsers: argparse._SubParsersAction
) -> FlexibleArgumentParser:
serve_parser = subparsers.add_parser(
"serve",
help="Launch the server and start FastAPI listener.",
usage="sgl_diffusion serve --model-path MODEL_PATH_OR_ID [OPTIONS]",
)
serve_parser = add_multimodal_gen_serve_args(serve_parser)
return cast(FlexibleArgumentParser, serve_parser)
def cmd_init() -> list[CLISubcommand]:
return [ServeSubcommand()]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import subprocess
import sys
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class RaiseNotImplementedAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
raise NotImplementedError(f"The {option_string} option is not yet implemented")
def launch_distributed(
num_gpus: int, args: list[str], master_port: int | None = None
) -> int:
"""
Launch a distributed job with the given arguments
Args:
num_gpus: Number of GPUs to use
        args: Arguments to pass to v1_sgl_diffusion_inference.py
        master_port: Port for the master process (default: let torch.distributed.run choose)
    Returns:
        The exit code of the launched subprocess.
    """
current_env = os.environ.copy()
python_executable = sys.executable
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../..")
)
main_script = os.path.join(
project_root, "sgl_diffusion/sample/v1_sgl_diffusion_inference.py"
)
cmd = [
python_executable,
"-m",
"torch.distributed.run",
f"--nproc_per_node={num_gpus}",
]
if master_port is not None:
cmd.append(f"--master_port={master_port}")
cmd.append(main_script)
cmd.extend(args)
logger.info("Running inference with %d GPU(s)", num_gpus)
logger.info("Launching command: %s", " ".join(cmd))
current_env["PYTHONIOENCODING"] = "utf-8"
process = subprocess.Popen(
cmd,
env=current_env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True,
bufsize=1,
encoding="utf-8",
errors="replace",
)
if process.stdout:
for line in iter(process.stdout.readline, ""):
print(line.strip())
return process.wait()
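# Illustrative sketch (argument values are placeholders): this helper is meant to
# be driven from the CLI layer, e.g.
#
#     exit_code = launch_distributed(num_gpus=4, args=["--model-path", "<MODEL_PATH_OR_ID>"])
#     sys.exit(exit_code)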
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
DiffGenerator module for sgl-diffusion.
This module provides a consolidated interface for generating images and videos
using diffusion models.
"""
import logging
import multiprocessing as mp
import os
import time
from copy import deepcopy
from typing import Any
import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange
# Suppress verbose logging from imageio, which is triggered when saving images.
logging.getLogger("imageio").setLevel(logging.WARNING)
logging.getLogger("imageio_ffmpeg").setLevel(logging.WARNING)
# Suppress Pillow plugin import logs when app log level is DEBUG
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger("PIL.Image").setLevel(logging.WARNING)
from sglang.multimodal_gen.configs.sample.base import DataType, SamplingParams
from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request
from sglang.multimodal_gen.runtime.launch_server import launch_server
from sglang.multimodal_gen.runtime.managers.schedulerbase import SchedulerBase
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req
from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs
from sglang.multimodal_gen.runtime.sync_scheduler_client import sync_scheduler_client
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
# TODO: move to somewhere appropriate
try:
# Set the start method to 'spawn' to avoid CUDA errors in forked processes.
# This must be done at the top level of the module, before any CUDA context
# or other processes are initialized.
mp.set_start_method("spawn", force=True)
except RuntimeError:
# The start method can only be set once per program execution.
pass
# TODO: rename
class DiffGenerator:
"""
A unified class for generating images/videos using diffusion models.
This class provides a simple interface for image/video generation with rich
customization options, similar to popular frameworks like HF Diffusers.
"""
def __init__(
self,
server_args: ServerArgs,
):
"""
Initialize the generator.
Args:
server_args: The inference arguments
"""
self.server_args = server_args
self.port_args = PortArgs.from_server_args(server_args)
# The executor is now a client to the Scheduler service
self.local_scheduler_process: list[mp.Process] | None = None
self.owns_scheduler_client: bool = False
@classmethod
def from_pretrained(
cls,
**kwargs,
) -> "DiffGenerator":
"""
Create a DiffGenerator from a pretrained model.
Args:
**kwargs: Additional arguments to customize model loading, set any ServerArgs or PipelineConfig attributes here.
Returns:
The created DiffGenerator
Priority level: Default pipeline config < User's pipeline config < User's kwargs
"""
# If users also provide some kwargs, it will override the ServerArgs and PipelineConfig.
if (server_args := kwargs.get("server_args", None)) is not None:
if isinstance(server_args, ServerArgs):
pass
elif isinstance(server_args, dict):
server_args = ServerArgs.from_kwargs(**server_args)
else:
server_args = ServerArgs.from_kwargs(**kwargs)
return cls.from_server_args(server_args)
@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "DiffGenerator":
"""
Create a DiffGenerator with the specified arguments.
Args:
server_args: The inference arguments
Returns:
The created DiffGenerator
"""
executor_class = SchedulerBase.get_class(server_args)
instance = cls(
server_args=server_args,
)
is_local_mode = server_args.is_local_mode
logger.info(f"Local mode: {is_local_mode}")
if is_local_mode:
instance.local_scheduler_process = instance._start_local_server_if_needed()
else:
# In remote mode, we just need to connect and check.
sync_scheduler_client.initialize(server_args)
instance._check_remote_scheduler()
# In both modes, this DiffGenerator instance is responsible for the client's lifecycle.
instance.owns_scheduler_client = True
return instance
def _start_local_server_if_needed(
self,
) -> list[mp.Process]:
"""Check if a local server is running; if not, start it and return the process handles."""
# First, we need a client to test the server. Initialize it temporarily.
sync_scheduler_client.initialize(self.server_args)
processes = launch_server(self.server_args, launch_http_server=False)
return processes
def _check_remote_scheduler(self):
"""Check if the remote scheduler is accessible."""
if not sync_scheduler_client.ping():
raise ConnectionError(
f"Could not connect to remote scheduler at "
f"{self.server_args.scheduler_endpoint()} with `local mode` as False. "
"Please ensure the server is running."
)
logger.info(
f"Successfully connected to remote scheduler at "
f"{self.server_args.scheduler_endpoint()}."
)
def post_process_sample(
self,
sample: torch.Tensor,
data_type: DataType,
fps: int,
save_output: bool = True,
        save_file_path: str | None = None,
):
"""
Process a single sample output and save output if necessary
"""
# Process outputs
if sample.dim() == 3:
# for images, dim t is missing
sample = sample.unsqueeze(1)
sample = rearrange(sample, "c t h w -> t c h w")
frames = []
# TODO: this can be batched
for x in sample:
x = torchvision.utils.make_grid(x, nrow=6)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
frames.append((x * 255).numpy().astype(np.uint8))
# Save outputs if requested
if save_output:
if save_file_path:
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
if data_type == DataType.VIDEO:
imageio.mimsave(
save_file_path,
frames,
fps=fps,
format=data_type.get_default_extension(),
)
else:
imageio.imwrite(save_file_path, frames[0])
logger.info("Saved output to %s", save_file_path)
else:
logger.warning("No output path provided, output not saved")
return frames
def generate(
self,
prompt: str | list[str] | None = None,
sampling_params: SamplingParams | None = None,
**kwargs,
) -> dict[str, Any] | list[np.ndarray] | list[dict[str, Any]] | None:
"""
        Generate an image or video based on the given prompt.
        Args:
            prompt: The prompt to use for generation (optional if
                server_args.prompt_file_path is provided)
            sampling_params: Optional SamplingParams to merge with the model's
                pretrained defaults
output_file_name: Name of the file to save. Default is the first 100 characters of the prompt.
save_output: Whether to save the output to disk
return_frames: Whether to return the raw frames
num_inference_steps: Number of denoising steps (overrides server_args)
guidance_scale: Classifier-free guidance scale (overrides server_args)
num_frames: Number of frames to generate (overrides server_args)
height: Height of generated file (overrides server_args)
width: Width of generated file (overrides server_args)
fps: Frames per second for saved file (overrides server_args)
seed: Random seed for generation (overrides server_args)
callback: Callback function called after each step
callback_steps: Number of steps between each callback
Returns:
Either the output dictionary, list of frames, or list of results for batch processing
"""
# 1. prepare requests
prompts: list[str] = []
# Handle batch processing from text file
if self.server_args.prompt_file_path is not None:
prompt_txt_path = self.server_args.prompt_file_path
if not os.path.exists(prompt_txt_path):
raise FileNotFoundError(
f"Prompt text file not found: {prompt_txt_path}"
)
# Read prompts from file
with open(prompt_txt_path, encoding="utf-8") as f:
prompts.extend(line.strip() for line in f if line.strip())
if not prompts:
raise ValueError(f"No prompts found in file: {prompt_txt_path}")
logger.info("Found %d prompts in %s", len(prompts), prompt_txt_path)
elif prompt is not None:
if isinstance(prompt, str):
prompts.append(prompt)
elif isinstance(prompt, list):
prompts.extend(prompt)
else:
raise ValueError("Either prompt or prompt_txt must be provided")
pretrained_sampling_params = SamplingParams.from_pretrained(
self.server_args.model_path, **kwargs
)
pretrained_sampling_params._merge_with_user_params(sampling_params)
# TODO: simplify
data_type = (
DataType.IMAGE
if self.server_args.pipeline_config.is_image_gen
            or pretrained_sampling_params.num_frames == 1
else DataType.VIDEO
)
        pretrained_sampling_params.data_type = data_type
pretrained_sampling_params.set_output_file_name()
requests: list[Req] = []
for output_idx, p in enumerate(prompts):
current_sampling_params = deepcopy(pretrained_sampling_params)
current_sampling_params.prompt = p
requests.append(
prepare_request(
p,
server_args=self.server_args,
sampling_params=current_sampling_params,
)
)
results = []
total_start_time = time.perf_counter()
# 2. send requests to scheduler, one at a time
# TODO: send batch when supported
for request_idx, req in enumerate(requests):
logger.info(
"Processing prompt %d/%d: %s...",
request_idx + 1,
len(requests),
req.prompt[:100],
)
try:
start_time = time.perf_counter()
output_batch = self._send_to_scheduler_and_wait_for_response([req])
gen_time = time.perf_counter() - start_time
if output_batch.error:
raise Exception(f"{output_batch.error}")
# FIXME: in generate mode, an internal assertion error won't raise an error
logger.info(
"Pixel data generated successfully in %.2f seconds",
gen_time,
)
if output_batch.output is None:
logger.error(
"Received empty output from scheduler for prompt %d",
request_idx + 1,
)
continue
for output_idx, sample in enumerate(output_batch.output):
num_outputs = len(output_batch.output)
output_file_name = req.output_file_name
if num_outputs > 1 and output_file_name:
base, ext = os.path.splitext(output_file_name)
output_file_name = f"{base}_{output_idx}{ext}"
save_path = (
os.path.join(req.output_path, output_file_name)
if output_file_name
else None
)
frames = self.post_process_sample(
sample,
fps=req.fps,
save_output=req.save_output,
save_file_path=save_path,
data_type=req.data_type,
)
result_item: dict[str, Any] = {
"samples": sample,
"frames": frames,
"prompts": req.prompt,
"size": (req.height, req.width, req.num_frames),
"generation_time": gen_time,
"logging_info": output_batch.logging_info,
"trajectory": output_batch.trajectory_latents,
"trajectory_timesteps": output_batch.trajectory_timesteps,
"trajectory_decoded": output_batch.trajectory_decoded,
"prompt_index": output_idx,
}
results.append(result_item)
except Exception as e:
logger.error(
"Failed to generate output for prompt %d: %s", request_idx + 1, e
)
continue
total_gen_time = time.perf_counter() - total_start_time
logger.info(
"Completed batch processing. Generated %d outputs in %.2f seconds.",
len(results),
total_gen_time,
)
if len(results) == 0:
return None
else:
if requests[0].return_frames:
results = [r["frames"] for r in results]
if len(results) == 1:
return results[0]
return results
def _send_to_scheduler_and_wait_for_response(self, batch: list[Req]) -> OutputBatch:
"""
Sends a request to the scheduler and waits for a response.
"""
return sync_scheduler_client.forward(batch)
def set_lora_adapter(
self, lora_nickname: str, lora_path: str | None = None
) -> None:
# self.scheduler.set_lora_adapter(lora_nickname, lora_path)
        pass  # TODO: forward to the scheduler once LoRA adapter management is supported
def unmerge_lora_weights(self) -> None:
"""
Use unmerged weights for inference to produce outputs that align with
validation outputs generated during training.
"""
# self.scheduler.unmerge_lora_weights()
        pass  # TODO: forward to the scheduler once LoRA weight unmerging is supported
def merge_lora_weights(self) -> None:
# self.scheduler.merge_lora_weights()
        pass  # TODO: forward to the scheduler once LoRA weight merging is supported
def shutdown(self):
"""
Shutdown the generator.
If in local mode, it also shuts down the scheduler server.
"""
# This sends the shutdown command to the server
# self.scheduler.shutdown()
if self.local_scheduler_process:
logger.info("Waiting for local worker processes to terminate...")
for process in self.local_scheduler_process:
process.join(timeout=10)
if process.is_alive():
logger.warning(
f"Local worker {process.name} did not terminate gracefully, forcing."
)
process.terminate()
self.local_scheduler_process = None
if self.owns_scheduler_client:
sync_scheduler_client.close()
self.owns_scheduler_client = False
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.shutdown()
def __del__(self):
if self.owns_scheduler_client:
logger.warning(
"Generator was garbage collected without being shut down. "
"Attempting to shut down the local server and client."
)
self.shutdown()
elif self.local_scheduler_process:
logger.warning(
"Generator was garbage collected without being shut down. "
"Attempting to shut down the local server."
)
self.shutdown()
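# Illustrative usage sketch (model path, prompt, and sampling fields are
# placeholders, not from the original file); the context-manager form ensures
# shutdown() is called even on error:
#
#     from sglang.multimodal_gen.configs.sample.base import SamplingParams
#
#     with DiffGenerator.from_pretrained(model_path="<MODEL_PATH_OR_ID>") as gen:
#         gen.generate(
#             prompt="a corgi surfing a wave at sunset",
#             sampling_params=SamplingParams(num_frames=1),  # single frame -> image output
#         )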
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
from sglang.multimodal_gen.runtime.entrypoints.openai import image_api, video_api
from sglang.multimodal_gen.runtime.server_args import ServerArgs, prepare_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import configure_logger
@asynccontextmanager
async def lifespan(app: FastAPI):
from sglang.multimodal_gen.runtime.scheduler_client import (
run_zeromq_broker,
scheduler_client,
)
# 1. Initialize the singleton client that connects to the backend Scheduler
server_args = app.state.server_args
scheduler_client.initialize(server_args)
# 2. Start the ZMQ Broker in the background to handle offline requests
broker_task = asyncio.create_task(run_zeromq_broker(server_args))
yield
# On shutdown
print("FastAPI app is shutting down...")
broker_task.cancel()
scheduler_client.close()
def create_app(server_args: ServerArgs):
"""
Create and configure the FastAPI application instance.
"""
app = FastAPI(lifespan=lifespan)
app.include_router(image_api.router)
app.include_router(video_api.router)
app.state.server_args = server_args
return app
if __name__ == "__main__":
import uvicorn
server_args = prepare_server_args([])
configure_logger(server_args)
app = create_app(server_args)
uvicorn.run(
app,
host=server_args.host,
port=server_args.port,
log_config=None,
reload=False, # Set to True during development for auto-reloading
)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import base64
import os
import time
from typing import List, Optional
from fastapi import APIRouter, File, Form, HTTPException, Path, Query, UploadFile
from fastapi.responses import FileResponse
from sglang.multimodal_gen.configs.sample.base import (
SamplingParams,
generate_request_id,
)
from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import (
ImageGenerationsRequest,
ImageResponse,
ImageResponseData,
)
from sglang.multimodal_gen.runtime.entrypoints.openai.stores import IMAGE_STORE
from sglang.multimodal_gen.runtime.entrypoints.openai.utils import (
_parse_size,
_save_upload_to_path,
post_process_sample,
)
from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.scheduler_client import scheduler_client
from sglang.multimodal_gen.runtime.server_args import get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
router = APIRouter(prefix="/v1/images", tags=["images"])
logger = init_logger(__name__)
def _choose_ext(output_format: Optional[str], background: Optional[str]) -> str:
# Normalize and choose extension
fmt = (output_format or "").lower()
if fmt in {"png", "webp", "jpeg", "jpg"}:
return "jpg" if fmt == "jpeg" else fmt
# If transparency requested, prefer png
if (background or "auto").lower() == "transparent":
return "png"
# Default
return "jpg"
def _build_sampling_params_from_request(
request_id: str,
prompt: str,
n: int,
size: Optional[str],
output_format: Optional[str],
background: Optional[str],
image_path: Optional[str] = None,
) -> SamplingParams:
width, height = _parse_size(size)
ext = _choose_ext(output_format, background)
server_args = get_global_server_args()
sampling_params = SamplingParams.from_pretrained(server_args.model_path)
# Build user params
user_params = SamplingParams(
request_id=request_id,
prompt=prompt,
image_path=image_path,
num_frames=1, # image
width=width,
height=height,
num_outputs_per_prompt=max(1, min(int(n or 1), 10)),
save_output=True,
)
# Let SamplingParams auto-generate a file name, then force desired extension
sampling_params = sampling_params.from_user_sampling_params(user_params)
if not sampling_params.output_file_name:
sampling_params.output_file_name = request_id
if not sampling_params.output_file_name.endswith(f".{ext}"):
# strip any existing extension and apply desired one
base = sampling_params.output_file_name.rsplit(".", 1)[0]
sampling_params.output_file_name = f"{base}.{ext}"
sampling_params.log(server_args)
return sampling_params
def _build_req_from_sampling(s: SamplingParams) -> Req:
return Req(
request_id=s.request_id,
data_type=s.data_type,
prompt=s.prompt,
image_path=s.image_path,
height=s.height,
width=s.width,
fps=1,
num_frames=s.num_frames,
seed=s.seed,
output_path=s.output_path,
output_file_name=s.output_file_name,
num_outputs_per_prompt=s.num_outputs_per_prompt,
save_output=s.save_output,
)
@router.post("/generations", response_model=ImageResponse)
async def generations(
request: ImageGenerationsRequest,
):
request_id = generate_request_id()
sampling = _build_sampling_params_from_request(
request_id=request_id,
prompt=request.prompt,
n=request.n or 1,
size=request.size,
output_format=request.output_format,
background=request.background,
)
batch = prepare_request(
prompt=request.prompt,
server_args=get_global_server_args(),
sampling_params=sampling,
)
# Run synchronously for images and save to disk
result = await scheduler_client.forward([batch])
save_file_path = os.path.join(batch.output_path, batch.output_file_name)
post_process_sample(
result.output[0],
batch.data_type,
1,
batch.save_output,
save_file_path,
)
await IMAGE_STORE.upsert(
request_id,
{
"id": request_id,
"created_at": int(time.time()),
"file_path": save_file_path,
},
)
resp_format = (request.response_format or "b64_json").lower()
if resp_format == "b64_json":
with open(save_file_path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("utf-8")
return ImageResponse(
data=[
ImageResponseData(
b64_json=b64,
revised_prompt=request.prompt,
)
]
)
else:
# Return error, not supported
raise HTTPException(
status_code=400, detail="response_format=url is not supported"
)
@router.post("/edits", response_model=ImageResponse)
async def edits(
image: Optional[List[UploadFile]] = File(None),
image_array: Optional[List[UploadFile]] = File(None, alias="image[]"),
prompt: str = Form(...),
mask: Optional[UploadFile] = File(None),
model: Optional[str] = Form(None),
n: Optional[int] = Form(1),
response_format: Optional[str] = Form(None),
size: Optional[str] = Form("1024x1024"),
output_format: Optional[str] = Form(None),
background: Optional[str] = Form("auto"),
user: Optional[str] = Form(None),
):
request_id = generate_request_id()
# Resolve images from either `image` or `image[]` (OpenAI SDK sends `image[]` when list is provided)
images = image or image_array
if not images or len(images) == 0:
raise HTTPException(status_code=422, detail="Field 'image' is required")
# Save first input image; additional images or mask are not yet used by the pipeline
uploads_dir = os.path.join("outputs", "uploads")
os.makedirs(uploads_dir, exist_ok=True)
first_image = images[0]
input_path = os.path.join(uploads_dir, f"{request_id}_{first_image.filename}")
await _save_upload_to_path(first_image, input_path)
sampling = _build_sampling_params_from_request(
request_id=request_id,
prompt=prompt,
n=n or 1,
size=size,
output_format=output_format,
background=background,
image_path=input_path,
)
batch = _build_req_from_sampling(sampling)
result = await scheduler_client.forward([batch])
save_file_path = os.path.join(batch.output_path, batch.output_file_name)
post_process_sample(
result.output[0],
batch.data_type,
1,
batch.save_output,
save_file_path,
)
await IMAGE_STORE.upsert(
request_id,
{
"id": request_id,
"created_at": int(time.time()),
"file_path": save_file_path,
},
)
# Default to b64_json to align with gpt-image-1 behavior in OpenAI examples
if (response_format or "b64_json").lower() == "b64_json":
with open(save_file_path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("utf-8")
return ImageResponse(
data=[ImageResponseData(b64_json=b64, revised_prompt=prompt)]
)
else:
url = f"/v1/images/{request_id}/content"
return ImageResponse(data=[ImageResponseData(url=url, revised_prompt=prompt)])
@router.get("/{image_id}/content")
async def download_image_content(
image_id: str = Path(...), variant: Optional[str] = Query(None)
):
item = await IMAGE_STORE.get(image_id)
if not item:
raise HTTPException(status_code=404, detail="Image not found")
file_path = item.get("file_path")
if not file_path or not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="Image is still being generated")
ext = os.path.splitext(file_path)[1].lower()
media_type = "image/jpeg"
if ext == ".png":
media_type = "image/png"
elif ext == ".webp":
media_type = "image/webp"
return FileResponse(
path=file_path, media_type=media_type, filename=os.path.basename(file_path)
)
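# Illustrative client sketch (host/port are placeholders; `requests` is only an
# example HTTP client, not a project dependency). Note that only
# response_format="b64_json" is supported by the generations endpoint above.
#
#     import requests
#     resp = requests.post(
#         "http://<HOST>:<PORT>/v1/images/generations",
#         json={"prompt": "a red fox in the snow", "size": "1024x1024",
#               "response_format": "b64_json"},
#     )
#     b64_image = resp.json()["data"][0]["b64_json"]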
import time
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
# Image API protocol models
class ImageResponseData(BaseModel):
b64_json: Optional[str] = None
url: Optional[str] = None
revised_prompt: Optional[str] = None
class ImageResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: List[ImageResponseData]
class ImageGenerationsRequest(BaseModel):
prompt: str
model: Optional[str] = None
n: Optional[int] = 1
quality: Optional[str] = "auto"
response_format: Optional[str] = "url" # url | b64_json
size: Optional[str] = "1024x1024" # e.g., 1024x1024
style: Optional[str] = "vivid"
background: Optional[str] = "auto" # transparent | opaque | auto
output_format: Optional[str] = None # png | jpeg | webp
user: Optional[str] = None
# Video API protocol models
class VideoResponse(BaseModel):
id: str
object: str = "video"
model: str = "sora-2"
status: str = "queued"
progress: int = 0
created_at: int = Field(default_factory=lambda: int(time.time()))
size: str = "720x1280"
seconds: str = "4"
quality: str = "standard"
remixed_from_video_id: Optional[str] = None
completed_at: Optional[int] = None
expires_at: Optional[int] = None
error: Optional[Dict[str, Any]] = None
class VideoGenerationsRequest(BaseModel):
prompt: str
input_reference: Optional[str] = None
model: Optional[str] = None
seconds: Optional[int] = 4
size: Optional[str] = "720x1280"
fps: Optional[int] = None
num_frames: Optional[int] = None
class VideoListResponse(BaseModel):
data: List[VideoResponse]
object: str = "list"
class VideoRemixRequest(BaseModel):
prompt: str
import asyncio
from typing import Any, Dict, List, Optional
class AsyncDictStore:
"""A small async-safe in-memory key-value store for dict items.
This encapsulates the usual pattern of a module-level dict guarded by
an asyncio.Lock and provides simple CRUD methods that are safe to call
concurrently from FastAPI request handlers and background tasks.
"""
def __init__(self) -> None:
self._items: Dict[str, Dict[str, Any]] = {}
self._lock = asyncio.Lock()
async def upsert(self, key: str, value: Dict[str, Any]) -> None:
async with self._lock:
self._items[key] = value
async def update_fields(
self, key: str, updates: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
async with self._lock:
item = self._items.get(key)
if item is None:
return None
item.update(updates)
return item
async def get(self, key: str) -> Optional[Dict[str, Any]]:
async with self._lock:
return self._items.get(key)
async def pop(self, key: str) -> Optional[Dict[str, Any]]:
async with self._lock:
return self._items.pop(key, None)
async def list_values(self) -> List[Dict[str, Any]]:
async with self._lock:
return list(self._items.values())
# Global stores shared by OpenAI entrypoints
VIDEO_STORE = AsyncDictStore()
IMAGE_STORE = AsyncDictStore()
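# Illustrative sketch (keys and fields are placeholders): request handlers share
# the stores above, e.g. from an async endpoint:
#
#     await VIDEO_STORE.upsert("job-123", {"id": "job-123", "status": "queued"})
#     await VIDEO_STORE.update_fields("job-123", {"status": "completed", "progress": 100})
#     job = await VIDEO_STORE.get("job-123")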
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import os
import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange
from fastapi import UploadFile
from sglang.multimodal_gen.configs.sample.base import DataType
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
def post_process_sample(
sample: torch.Tensor,
data_type: DataType,
fps: int,
save_output: bool = True,
    save_file_path: str | None = None,
):
"""
Process sample output and save video if necessary
"""
# Process outputs
if sample.dim() == 3:
# for images, dim t is missing
sample = sample.unsqueeze(1)
videos = rearrange(sample, "c t h w -> t c h w")
frames = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=6)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
frames.append((x * 255).numpy().astype(np.uint8))
# Save outputs if requested
if save_output:
if save_file_path:
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
if data_type == DataType.VIDEO:
imageio.mimsave(
save_file_path,
frames,
fps=fps,
format=data_type.get_default_extension(),
)
else:
imageio.imwrite(save_file_path, frames[0])
logger.info(f"Saved output to {save_file_path}")
else:
logger.info(f"No output path provided, output not saved")
return frames
def _parse_size(size: str) -> tuple[int, int]:
try:
parts = size.lower().replace(" ", "").split("x")
if len(parts) != 2:
raise ValueError
w, h = int(parts[0]), int(parts[1])
return w, h
except Exception:
# Fallback to default portrait 720x1280
return 720, 1280
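# Examples (for illustration): "1024x1024" -> (1024, 1024); an unparsable value
# such as "square" falls back to the default (720, 1280).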
# Helpers
async def _save_upload_to_path(upload: UploadFile, target_path: str) -> str:
os.makedirs(os.path.dirname(target_path), exist_ok=True)
content = await upload.read()
with open(target_path, "wb") as f:
f.write(content)
return target_path
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import asyncio
import json
import os
import time
from typing import Any, Dict, Optional
from fastapi import (
APIRouter,
File,
Form,
HTTPException,
Path,
Query,
Request,
UploadFile,
)
from fastapi.responses import FileResponse
from sglang.multimodal_gen.configs.sample.base import (
SamplingParams,
generate_request_id,
)
from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import (
VideoGenerationsRequest,
VideoListResponse,
VideoResponse,
)
from sglang.multimodal_gen.runtime.entrypoints.openai.stores import VIDEO_STORE
from sglang.multimodal_gen.runtime.entrypoints.openai.utils import (
_parse_size,
_save_upload_to_path,
post_process_sample,
)
from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.server_args import get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
router = APIRouter(prefix="/v1/videos", tags=["videos"])
def _build_sampling_params_from_request(
request_id: str, request: VideoGenerationsRequest
) -> SamplingParams:
width, height = _parse_size(request.size or "720x1280")
seconds = request.seconds if request.seconds is not None else 4
# Prefer user-provided fps/num_frames from request; fallback to defaults
fps_default = 24
fps = request.fps if request.fps is not None else fps_default
# If user provides num_frames, use it directly; otherwise derive from seconds * fps
derived_num_frames = fps * seconds
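    # e.g. (illustrative) seconds=4 at fps=24 derives 96 frames unless num_frames is given explicitly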
num_frames = (
request.num_frames if request.num_frames is not None else derived_num_frames
)
server_args = get_global_server_args()
# TODO: should we cache this sampling_params?
sampling_params = SamplingParams.from_pretrained(server_args.model_path)
user_params = SamplingParams(
request_id=request_id,
prompt=request.prompt,
num_frames=num_frames,
fps=fps,
width=width,
height=height,
image_path=request.input_reference,
save_output=True,
)
sampling_params = sampling_params.from_user_sampling_params(user_params)
sampling_params.log(server_args)
sampling_params.set_output_file_ext()
return sampling_params
# Extract the job metadata that the HTTP server needs to track.
def _video_job_from_sampling(
request_id: str, req: VideoGenerationsRequest, sampling: SamplingParams
) -> Dict[str, Any]:
size_str = f"{sampling.width}x{sampling.height}"
seconds = int(round((sampling.num_frames or 0) / float(sampling.fps or 24)))
return {
"id": request_id,
"object": "video",
"model": req.model or "sora-2",
"status": "queued",
"progress": 0,
"created_at": int(time.time()),
"size": size_str,
"seconds": str(seconds),
"quality": "standard",
"file_path": sampling.output_file_path(),
}
async def _dispatch_job_async(job_id: str, batch: Req) -> None:
from sglang.multimodal_gen.runtime.scheduler_client import scheduler_client
try:
result = await scheduler_client.forward([batch])
post_process_sample(
result.output[0],
batch.data_type,
batch.fps,
batch.save_output,
os.path.join(batch.output_path, batch.output_file_name),
)
await VIDEO_STORE.update_fields(
job_id,
{"status": "completed", "progress": 100, "completed_at": int(time.time())},
)
except Exception as e:
logger.error(f"{e}")
await VIDEO_STORE.update_fields(
job_id, {"status": "failed", "error": {"message": str(e)}}
)
# TODO: support image to video generation
@router.post("", response_model=VideoResponse)
async def create_video(
request: Request,
# multipart/form-data fields (optional; used only when content-type is multipart)
prompt: Optional[str] = Form(None),
input_reference: Optional[UploadFile] = File(None),
model: Optional[str] = Form(None),
seconds: Optional[int] = Form(None),
size: Optional[str] = Form(None),
fps: Optional[int] = Form(None),
num_frames: Optional[int] = Form(None),
extra_body: Optional[str] = Form(None),
):
content_type = request.headers.get("content-type", "").lower()
request_id = generate_request_id()
if "multipart/form-data" in content_type:
if not prompt:
raise HTTPException(status_code=400, detail="prompt is required")
if input_reference is None:
raise HTTPException(
status_code=400, detail="input_reference file is required"
)
uploads_dir = os.path.join("outputs", "uploads")
os.makedirs(uploads_dir, exist_ok=True)
input_path = os.path.join(
uploads_dir, f"{request_id}_{input_reference.filename}"
)
await _save_upload_to_path(input_reference, input_path)
# Parse extra_body JSON (if provided in multipart form) to get fps/num_frames overrides
extra_from_form: Dict[str, Any] = {}
        if extra_body:
            try:
                extra_from_form = json.loads(extra_body)
            except Exception:
                logger.warning(
                    "Failed to parse extra_body form field as JSON; ignoring it"
                )
                extra_from_form = {}
fps_val = fps if fps is not None else extra_from_form.get("fps")
num_frames_val = (
num_frames if num_frames is not None else extra_from_form.get("num_frames")
)
req = VideoGenerationsRequest(
prompt=prompt,
input_reference=input_path,
model=model,
seconds=seconds if seconds is not None else 4,
size=size or "720x1280",
fps=fps_val,
num_frames=num_frames_val,
)
else:
try:
body = await request.json()
except Exception:
body = {}
try:
# If client uses extra_body, merge it into the top-level payload
payload: Dict[str, Any] = dict(body or {})
extra = payload.pop("extra_body", None)
if isinstance(extra, dict):
# Shallow-merge: only keys like fps/num_frames are expected
payload.update(extra)
req = VideoGenerationsRequest(**payload)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid request body: {e}")
logger.debug(f"Server received from create_video endpoint: req={req}")
sampling_params = _build_sampling_params_from_request(request_id, req)
job = _video_job_from_sampling(request_id, req, sampling_params)
await VIDEO_STORE.upsert(request_id, job)
# Build Req for scheduler
batch = prepare_request(
prompt=req.prompt,
server_args=get_global_server_args(),
sampling_params=sampling_params,
)
# Enqueue the job asynchronously and return immediately
asyncio.create_task(_dispatch_job_async(request_id, batch))
return VideoResponse(**job)
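# Hypothetical client-side sketch for the endpoint above; the base URL and file
# name are assumptions, so adjust them to your deployment. It shows both the
# JSON body (with extra_body merged server-side) and the multipart form upload.
def _example_create_video_clients(base_url: str = "http://localhost:8000") -> None:
    import requests

    # JSON body: extra_body keys such as fps/num_frames are merged into the payload
    job = requests.post(
        f"{base_url}/v1/videos",
        json={
            "prompt": "a red panda drinking tea",
            "seconds": 4,
            "size": "720x1280",
            "extra_body": {"fps": 24},
        },
    ).json()
    logger.debug(f"queued job {job['id']} with status {job['status']}")

    # multipart/form-data: the upload is saved under outputs/uploads and used as input_reference
    with open("reference.png", "rb") as f:
        job = requests.post(
            f"{base_url}/v1/videos",
            data={"prompt": "animate this image", "seconds": 4, "size": "720x1280"},
            files={"input_reference": f},
        ).json()
    logger.debug(f"queued job {job['id']} with status {job['status']}")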
@router.get("", response_model=VideoListResponse)
async def list_videos(
after: Optional[str] = Query(None),
limit: Optional[int] = Query(None, ge=1, le=100),
order: Optional[str] = Query("desc"),
):
# Normalize order
order = (order or "desc").lower()
if order not in ("asc", "desc"):
order = "desc"
jobs = await VIDEO_STORE.list_values()
reverse = order != "asc"
jobs.sort(key=lambda j: j.get("created_at", 0), reverse=reverse)
if after is not None:
try:
idx = next(i for i, j in enumerate(jobs) if j["id"] == after)
jobs = jobs[idx + 1 :]
except StopIteration:
jobs = []
if limit is not None:
jobs = jobs[:limit]
items = [VideoResponse(**j) for j in jobs]
return VideoListResponse(data=items)
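# Hypothetical pagination sketch: `after` is an exclusive cursor (the id of the
# last item already seen), so a client can walk the whole list page by page.
# The base URL is an assumption.
def _example_paginate_videos(base_url: str = "http://localhost:8000") -> list:
    import requests

    items, after = [], None
    while True:
        params: Dict[str, Any] = {"limit": 100, "order": "desc"}
        if after is not None:
            params["after"] = after
        page = requests.get(f"{base_url}/v1/videos", params=params).json()["data"]
        if not page:
            break
        items.extend(page)
        after = page[-1]["id"]
    return items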
@router.get("/{video_id}", response_model=VideoResponse)
async def retrieve_video(video_id: str = Path(...)):
job = await VIDEO_STORE.get(video_id)
if not job:
raise HTTPException(status_code=404, detail="Video not found")
return VideoResponse(**job)
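# Hypothetical polling sketch: create_video returns immediately with status
# "queued", and the background task flips it to "completed" or "failed", so a
# client polls this endpoint. The base URL and poll interval are assumptions.
def _example_wait_for_video(
    video_id: str, base_url: str = "http://localhost:8000"
) -> Dict[str, Any]:
    import requests

    while True:
        job = requests.get(f"{base_url}/v1/videos/{video_id}").json()
        if job["status"] in ("completed", "failed"):
            return job
        time.sleep(2.0)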
# TODO: support aborting a job.
@router.delete("/{video_id}", response_model=VideoResponse)
async def delete_video(video_id: str = Path(...)):
job = await VIDEO_STORE.pop(video_id)
if not job:
raise HTTPException(status_code=404, detail="Video not found")
# Mark as deleted in response semantics
job["status"] = "deleted"
return VideoResponse(**job)
@router.get("/{video_id}/content")
async def download_video_content(
video_id: str = Path(...), variant: Optional[str] = Query(None)
):
job = await VIDEO_STORE.get(video_id)
if not job:
raise HTTPException(status_code=404, detail="Video not found")
file_path = job.get("file_path")
    if not file_path or not os.path.exists(file_path):
        raise HTTPException(
            status_code=404,
            detail="Video content is not available yet; generation may still be in progress or may have failed",
        )
media_type = "video/mp4" # default variant
return FileResponse(
path=file_path, media_type=media_type, filename=os.path.basename(file_path)
)
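# Hypothetical download sketch: once a job is "completed", its MP4 is served by
# the content endpoint above. The base URL and output name are assumptions.
def _example_download_video(
    video_id: str, base_url: str = "http://localhost:8000"
) -> str:
    import requests

    out_path = f"{video_id}.mp4"
    with requests.get(f"{base_url}/v1/videos/{video_id}/content", stream=True) as r:
        r.raise_for_status()
        with open(out_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return out_path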
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
DiffGenerator module for sgl-diffusion.
This module provides a consolidated interface for generating videos using
diffusion models.
"""
import logging
import math
# Suppress verbose logging from imageio, which is triggered when saving images.
logging.getLogger("imageio").setLevel(logging.WARNING)
logging.getLogger("imageio_ffmpeg").setLevel(logging.WARNING)
from sglang.multimodal_gen.configs.sample.base import DataType, SamplingParams
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import shallow_asdict
logger = init_logger(__name__)
def prepare_sampling_params(
prompt: str,
server_args: ServerArgs,
sampling_params: SamplingParams,
):
pipeline_config = server_args.pipeline_config
# Validate inputs
if not isinstance(prompt, str):
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
# Process negative prompt
if (
sampling_params.negative_prompt is not None
and not sampling_params.negative_prompt.isspace()
):
# avoid stripping default negative prompt: ' ' for qwen-image
sampling_params.negative_prompt = sampling_params.negative_prompt.strip()
    # Validate frame count
    if sampling_params.num_frames <= 0:
        raise ValueError(
            f"`num_frames` must be a positive integer, got "
            f"num_frames={sampling_params.num_frames}"
        )
temporal_scale_factor = (
pipeline_config.vae_config.arch_config.temporal_compression_ratio
)
    # Settle num_frames: image-generation models always produce a single frame
    if server_args.pipeline_config.is_image_gen:
        logger.debug("Setting num_frames to 1 because this is an image-gen model")
        sampling_params.num_frames = 1
num_frames = sampling_params.num_frames
num_gpus = server_args.num_gpus
use_temporal_scaling_frames = pipeline_config.vae_config.use_temporal_scaling_frames
# Adjust number of frames based on number of GPUs
if use_temporal_scaling_frames:
orig_latent_num_frames = (num_frames - 1) // temporal_scale_factor + 1
else: # stepvideo only
orig_latent_num_frames = sampling_params.num_frames // 17 * 3
if orig_latent_num_frames % server_args.num_gpus != 0:
# Adjust latent frames to be divisible by number of GPUs
if sampling_params.num_frames_round_down:
# Ensure we have at least 1 batch per GPU
new_latent_num_frames = (
max(1, (orig_latent_num_frames // num_gpus)) * num_gpus
)
else:
new_latent_num_frames = (
math.ceil(orig_latent_num_frames / num_gpus) * num_gpus
)
if use_temporal_scaling_frames:
# Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor
new_num_frames = (new_latent_num_frames - 1) * temporal_scale_factor + 1
else: # stepvideo only
# Find the least common multiple of 3 and num_gpus
divisor = math.lcm(3, num_gpus)
# Round up to the nearest multiple of this LCM
new_latent_num_frames = (
(new_latent_num_frames + divisor - 1) // divisor
) * divisor
# Convert back to actual frames using the StepVideo formula
new_num_frames = new_latent_num_frames // 3 * 17
logger.info(
"Adjusting number of frames from %s to %s based on number of GPUs (%s)",
sampling_params.num_frames,
new_num_frames,
server_args.num_gpus,
)
sampling_params.num_frames = new_num_frames
if pipeline_config.is_image_gen:
sampling_params.data_type = DataType.IMAGE
sampling_params.set_output_file_ext()
sampling_params.log(server_args=server_args)
return sampling_params
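# Worked example of the adjustment above, using hypothetical values: a VAE with
# temporal_compression_ratio=4, 49 requested frames, and 2 GPUs, rounding up.
def _example_frame_adjustment(
    num_frames: int = 49, temporal_scale_factor: int = 4, num_gpus: int = 2
) -> int:
    latent_frames = (num_frames - 1) // temporal_scale_factor + 1  # 49 -> 13
    if latent_frames % num_gpus != 0:
        latent_frames = math.ceil(latent_frames / num_gpus) * num_gpus  # 13 -> 14
    return (latent_frames - 1) * temporal_scale_factor + 1  # 14 -> 53 frames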
def prepare_request(
prompt: str,
server_args: ServerArgs,
sampling_params: SamplingParams,
) -> Req:
"""
Settle SamplingParams according to ServerArgs
"""
# Create a copy of inference args to avoid modifying the original
sampling_params = prepare_sampling_params(prompt, server_args, sampling_params)
req = Req(
**shallow_asdict(sampling_params),
VSA_sparsity=server_args.VSA_sparsity,
)
# req.set_width_and_height(server_args)
# if (req.width <= 0
# or req.height <= 0):
# raise ValueError(
# f"Height, width must be positive integers, got "
# f"height={req.height}, width={req.width}"
# )
return req
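# Minimal usage sketch for the helpers above: load the model's default
# SamplingParams and turn a prompt into a scheduler Req. Assumes server_args
# carries a valid model_path; the prompt is a placeholder.
def _example_prepare_request(server_args: ServerArgs) -> Req:
    params = SamplingParams.from_pretrained(server_args.model_path)
    return prepare_request(
        prompt="a red panda drinking tea",
        server_args=server_args,
        sampling_params=params,
    )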
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import multiprocessing as mp
import uvicorn
from sglang.multimodal_gen.runtime.entrypoints.http_server import create_app
from sglang.multimodal_gen.runtime.managers.gpu_worker import run_scheduler_process
from sglang.multimodal_gen.runtime.server_args import ServerArgs, set_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import (
configure_logger,
logger,
suppress_other_loggers,
)
def launch_server(server_args: ServerArgs, launch_http_server: bool = True):
"""
Args:
launch_http_server: False for offline local mode
"""
configure_logger(server_args)
suppress_other_loggers()
# Start a new server with multiple worker processes
logger.info("Starting server...")
num_gpus = server_args.num_gpus
processes = []
# Pipes for master to talk to slaves
task_pipes_to_slaves_w = []
task_pipes_to_slaves_r = []
for _ in range(num_gpus - 1):
r, w = mp.Pipe(duplex=False)
task_pipes_to_slaves_r.append(r)
task_pipes_to_slaves_w.append(w)
# Pipes for slaves to talk to master
result_pipes_from_slaves_w = []
result_pipes_from_slaves_r = []
for _ in range(num_gpus - 1):
r, w = mp.Pipe(duplex=False)
result_pipes_from_slaves_r.append(r)
result_pipes_from_slaves_w.append(w)
# Launch all worker processes
    # Fall back to an offset of the HTTP port when no master_port is configured
    master_port = server_args.master_port or (server_args.port + 100)
scheduler_pipe_readers = []
scheduler_pipe_writers = []
for i in range(num_gpus):
reader, writer = mp.Pipe(duplex=False)
scheduler_pipe_writers.append(writer)
if i == 0: # Master worker
process = mp.Process(
target=run_scheduler_process,
args=(
i, # local_rank
i, # rank
master_port,
server_args,
writer,
None, # No task pipe to read from master
None, # No result pipe to write to master
task_pipes_to_slaves_w,
result_pipes_from_slaves_r,
),
name=f"sgl-diffusionWorker-{i}",
daemon=True,
)
else: # Slave workers
process = mp.Process(
target=run_scheduler_process,
args=(
i, # local_rank
i, # rank
master_port,
server_args,
writer,
None, # No task pipe to read from master
None, # No result pipe to write to master
task_pipes_to_slaves_r[i - 1],
result_pipes_from_slaves_w[i - 1],
),
name=f"sgl-diffusionWorker-{i}",
daemon=True,
)
scheduler_pipe_readers.append(reader)
process.start()
processes.append(process)
    # The scheduler processes now own the writer ends; close the parent's copies
    scheduler_infos = []
    for writer in scheduler_pipe_writers:
        writer.close()
# Close unused pipe ends in parent process
for p in task_pipes_to_slaves_w:
p.close()
for p in task_pipes_to_slaves_r:
p.close()
for p in result_pipes_from_slaves_w:
p.close()
for p in result_pipes_from_slaves_r:
p.close()
    # Wait for every worker to report that it is ready
    for i, reader in enumerate(scheduler_pipe_readers):
try:
data = reader.recv()
except EOFError:
logger.error(
f"Rank {i} scheduler is dead. Please check if there are relevant logs."
)
processes[i].join()
logger.error(f"Exit code: {processes[i].exitcode}")
raise
if data["status"] != "ready":
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
scheduler_infos.append(data)
reader.close()
logger.debug("All workers are ready")
if launch_http_server:
logger.info("Starting FastAPI server.")
        # Make server_args available to endpoints via get_global_server_args()
set_global_server_args(server_args)
app = create_app(server_args)
uvicorn.run(
app,
log_config=None,
log_level=server_args.log_level,
host=server_args.host,
port=server_args.port,
reload=False,
)
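# Hypothetical launch sketch. ServerArgs is defined elsewhere; the fields shown
# here (model_path, num_gpus, host, port) are assumptions based on the
# attributes used in this PR, and the model path is a placeholder.
if __name__ == "__main__":
    _example_args = ServerArgs(
        model_path="<path-or-hf-id-of-a-diffusion-model>",
        num_gpus=2,
        host="0.0.0.0",
        port=8000,
    )
    launch_server(_example_args, launch_http_server=True)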
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo