Unverified Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
class MLP(nn.Module):
"""
MLP for DiT blocks, NO gated linear units
"""
def __init__(
self,
input_dim: int,
mlp_hidden_dim: int,
output_dim: int | None = None,
bias: bool = True,
act_type: str = "gelu_pytorch_tanh",
dtype: torch.dtype | None = None,
prefix: str = "",
):
super().__init__()
self.fc_in = ReplicatedLinear(
input_dim,
            mlp_hidden_dim,  # For activation funcs like SiLU that need 2x width
bias=bias,
params_dtype=dtype,
)
self.act = get_act_fn(act_type)
if output_dim is None:
output_dim = input_dim
self.fc_out = ReplicatedLinear(
mlp_hidden_dim, output_dim, bias=bias, params_dtype=dtype
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.fc_in(x)
x = self.act(x)
x, _ = self.fc_out(x)
return x
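# Illustrative usage sketch, not part of the original file: instantiates the MLP above and
# checks that shapes round-trip. The dimensions are arbitrary assumptions, and it presumes
# ReplicatedLinear / get_act_fn need no extra runtime setup.
def _example_mlp_usage() -> None:
    mlp = MLP(input_dim=64, mlp_hidden_dim=256, act_type="gelu_pytorch_tanh")
    x = torch.randn(2, 16, 64)  # [batch, tokens, input_dim]
    y = mlp(x)
    assert y.shape == x.shape  # output_dim defaults to input_dim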
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from typing import Literal, get_args
from sglang.multimodal_gen.runtime.layers.quantization.base_config import (
QuantizationConfig,
)
QuantizationMethods = Literal[None]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
# Customized quantization methods are registered into this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
def register_quantization_config(quantization: str):
"""Register a customized vllm quantization config.
When a quantization method is not supported by vllm, you can register a customized
quantization config to support it.
Args:
quantization (str): The quantization method name.
Examples:
>>> from sglang.multimodal_gen.runtime.layers.quantization import register_quantization_config
>>> from sglang.multimodal_gen.runtime.layers.quantization import get_quantization_config
>>> from sglang.multimodal_gen.runtime.layers.quantization.base_config import QuantizationConfig
>>>
>>> @register_quantization_config("my_quant")
... class MyQuantConfig(QuantizationConfig):
... pass
>>>
>>> get_quantization_config("my_quant")
<class 'MyQuantConfig'>
""" # noqa: E501
def _wrapper(quant_config_cls):
if quantization in QUANTIZATION_METHODS:
raise ValueError(
f"The quantization method `{quantization}` is already exists."
)
if not issubclass(quant_config_cls, QuantizationConfig):
raise ValueError(
"The quantization config must be a subclass of " "`QuantizationConfig`."
)
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
QUANTIZATION_METHODS.append(quantization)
return quant_config_cls
return _wrapper
def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
method_to_config: dict[str, type[QuantizationConfig]] = {}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
return method_to_config[quantization]
__all__ = [
"QuantizationMethods",
"QuantizationConfig",
"get_quantization_config",
"QUANTIZATION_METHODS",
]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/quantization/base_config.py
import inspect
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import torch
from torch import nn
if TYPE_CHECKING:
from sglang.multimodal_gen.runtime.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@abstractmethod
def create_weights(
self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
# Not required functions
def embedding(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Gather embeddings in the layer based on indices in the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
        This can be used, for example, to transpose weights for computation.
"""
return
def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool:
"""
Not all quant methods have embedding implemented, so we need to check that
it exists for our given method. We check this by making sure the function
has been changed from the base implementation.
"""
base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
class_embedding = inspect.getattr_static(method_class, "embedding", None)
return class_embedding is not None and class_embedding is not base_embedding
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
def __init__(self):
super().__init__()
# mapping is updated by models as they initialize
self.packed_modules_mapping: dict[str, list[str]] = dict()
@abstractmethod
def get_name(self) -> QuantizationMethods:
"""Name of the quantization method."""
raise NotImplementedError
@abstractmethod
def get_supported_act_dtypes(self) -> list[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def get_config_filenames() -> list[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
@abstractmethod
def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
this method should only be overwritten by subclasses in exceptional
circumstances
"""
return None
@staticmethod
def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
raise ValueError(
f"Cannot find any of {keys} in the model's " "quantization config."
)
@staticmethod
def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any:
"""Get a optional value from the model's quantization config."""
try:
return QuantizationConfig.get_from_keys(config, keys)
except ValueError:
return default
@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> QuantizeMethodBase | None:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError
def get_cache_scale(self, name: str) -> str | None:
return None
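# Illustrative sketch, not part of the original file: the static helpers above can read
# fields out of a HF-style quantization config dict; the dict below is an assumption.
def _example_get_from_keys() -> None:
    cfg = {"bits": 4, "group_size": 128}
    bits = QuantizationConfig.get_from_keys(cfg, ["bits", "wbits"])  # -> 4
    sym = QuantizationConfig.get_from_keys_or(cfg, ["sym"], default=True)  # -> True
    assert bits == 4 and sym is True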
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/rotary_embedding.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import functools
from collections import OrderedDict
from typing import Any
import torch
from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_group
from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp
from sglang.multimodal_gen.runtime.layers.triton_ops import apply_rotary_embedding
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
interleaved: bool = False,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size] or [num_tokens, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
# cos = cos.unsqueeze(-2).to(x.dtype)
# sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = (x1.float() * cos - x2.float() * sin).type_as(x)
o2 = (x2.float() * cos + x1.float() * sin).type_as(x)
return torch.cat((o1, o2), dim=-1)
else:
return apply_rotary_embedding(x, cos, sin, interleaved)
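# Illustrative sketch, not part of the original file: exercises the Neox-style branch of
# _apply_rotary_emb on CPU tensors with the shapes described in the docstring. The GPT-J
# style branch dispatches to the Triton kernel and therefore needs a CUDA device.
def _example_apply_rotary_emb_neox() -> None:
    num_tokens, num_heads, head_size = 4, 2, 8
    x = torch.randn(num_tokens, num_heads, head_size)
    cos = torch.randn(num_tokens, head_size // 2)
    sin = torch.randn(num_tokens, head_size // 2)
    out = _apply_rotary_emb(x, cos, sin, is_neox_style=True)
    assert out.shape == x.shape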
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int | float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: int | float) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_cuda(self, *args, **kwargs) -> Any:
return self.forward_native(*args, **kwargs)
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
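# Illustrative sketch, not part of the original file: runs the PyTorch-native path of
# RotaryEmbedding on CPU. Sizes are arbitrary assumptions, and it presumes the CustomOp
# base class needs no extra runtime setup.
def _example_rotary_embedding() -> None:
    rope = RotaryEmbedding(
        head_size=16,
        rotary_dim=16,
        max_position_embeddings=128,
        base=10000.0,
        is_neox_style=True,
        dtype=torch.float32,
    )
    num_tokens, num_heads = 8, 2
    positions = torch.arange(num_tokens)
    query = torch.randn(num_tokens, num_heads * 16)
    key = torch.randn(num_tokens, num_heads * 16)
    q_rot, k_rot = rope.forward_native(positions, query, key)
    assert q_rot.shape == query.shape and k_rot.shape == key.shape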
class OneDRotaryEmbedding(torch.nn.Module):
"""1D rotary positional embedding with caching."""
def __init__(
self,
dim: int,
theta: float = 10000.0,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
dtype: torch.dtype = torch.float32,
use_real: bool = False,
repeat_interleave_real: bool = False,
):
super().__init__()
assert dim % 2 == 0
self.dim = dim
self.theta = theta
self.theta_rescale_factor = theta_rescale_factor
self.interpolation_factor = interpolation_factor
# dtype of freqs
self.dtype = dtype
self.use_real = use_real
self.repeat_interleave_real = repeat_interleave_real
def build_freqs(self, device):
freqs = 1.0 / (
self.theta
** (
torch.arange(0, self.dim, 2, dtype=self.dtype, device=device)[
: (self.dim // 2)
]
/ self.dim
).to(device=device)
)
return freqs
def build_freqs_outer(self, pos: torch.Tensor, device):
theta = self.theta
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if self.theta_rescale_factor != 1.0:
theta *= self.theta_rescale_factor ** (self.dim / (self.dim - 2))
freqs = self.build_freqs(device)
freqs = torch.outer(pos * self.interpolation_factor, freqs)
freqs_cos = freqs.cos()
freqs_sin = freqs.sin()
if self.use_real and self.repeat_interleave_real:
freqs_cos = freqs_cos.repeat_interleave(2, dim=1)
freqs_sin = freqs_sin.repeat_interleave(2, dim=1)
return freqs_cos.float(), freqs_sin.float()
@functools.lru_cache(maxsize=16)
def forward_from_grid(
self, seq_len: int, start_pos: int, device_str: str
) -> tuple[torch.Tensor, torch.Tensor]:
device = torch.device(device_str)
pos = torch.arange(
start_pos, start_pos + seq_len, dtype=self.dtype, device=device
)
freqs_cos, freqs_sin = self.build_freqs_outer(pos, device)
return freqs_cos, freqs_sin
def forward(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculates 1D rotary embeddings for the given positions.
This method converts the input tensor to a hashable representation
and calls a cached helper method to perform the computation.
"""
pos_tuple = tuple(pos.tolist())
device_str = str(pos.device)
return self._forward_cached(pos_tuple, device_str)
@functools.lru_cache(maxsize=16)
def _forward_cached(
self, pos_tuple: tuple, device_str: str
) -> tuple[torch.Tensor, torch.Tensor]:
"""
The core implementation that computes 1D rotary embeddings.
This method is wrapped by an LRU cache.
"""
device = torch.device(device_str)
pos = torch.as_tensor(pos_tuple, dtype=self.dtype, device=device)
freqs_cos, freqs_sin = self.build_freqs_outer(pos, device)
return freqs_cos, freqs_sin
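# Illustrative sketch, not part of the original file: computes 1D RoPE cos/sin tables for
# a few positions; repeated calls with the same positions hit the LRU cache.
def _example_one_d_rope() -> None:
    rope_1d = OneDRotaryEmbedding(dim=32, theta=10000.0)
    pos = torch.arange(6, dtype=torch.float32)
    cos, sin = rope_1d(pos)
    assert cos.shape == (6, 16) and sin.shape == (6, 16)  # [S, dim // 2]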
class NDRotaryEmbedding(torch.nn.Module):
"""N-dimensional rotary positional embedding."""
def __init__(
self,
rope_dim_list: list[int],
rope_theta: float,
theta_rescale_factor: float | list[float] = 1.0,
interpolation_factor: float | list[float] = 1.0,
use_real: bool = False,
repeat_interleave_real: bool = False,
dtype: torch.dtype = torch.float32,
):
super().__init__()
self.rope_dim_list = rope_dim_list
self.ndim = len(rope_dim_list)
self.rope_theta = rope_theta
# dtype of freqs
# does not control the output dtype
self.dtype = dtype
if isinstance(theta_rescale_factor, (int, float)):
self.theta_rescale_factor = [theta_rescale_factor] * self.ndim
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
self.theta_rescale_factor = [theta_rescale_factor[0]] * self.ndim
else:
self.theta_rescale_factor = theta_rescale_factor
assert (
len(self.theta_rescale_factor) == self.ndim
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, (int, float)):
self.interpolation_factor = [interpolation_factor] * self.ndim
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
self.interpolation_factor = [interpolation_factor[0]] * self.ndim
else:
self.interpolation_factor = interpolation_factor
assert (
len(self.interpolation_factor) == self.ndim
), "len(interpolation_factor) should equal to len(rope_dim_list)"
self.rope_generators: list[OneDRotaryEmbedding] = torch.nn.ModuleList()
_config_to_gen_idx: dict[tuple, int] = {}
self.dim_idx_to_gen_idx: list[int] = []
for i in range(self.ndim):
dim = self.rope_dim_list[i]
rescale = self.theta_rescale_factor[i]
interp = self.interpolation_factor[i]
config_key = (dim, rescale, interp, use_real, repeat_interleave_real)
if config_key not in _config_to_gen_idx:
generator = OneDRotaryEmbedding(
dim=dim,
theta=self.rope_theta,
theta_rescale_factor=rescale,
interpolation_factor=interp,
dtype=self.dtype,
use_real=use_real,
repeat_interleave_real=repeat_interleave_real,
)
_config_to_gen_idx[config_key] = len(self.rope_generators)
self.rope_generators.append(generator)
gen_idx = _config_to_gen_idx[config_key]
self.dim_idx_to_gen_idx.append(gen_idx)
def forward(self, positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculates n-d rotary embeddings for given absolute positions.
Args:
positions (torch.Tensor): A tensor of shape `[num_tokens, ndim]`
containing the integer coordinates for each token.
Returns:
A tuple of (cos, sin) tensors.
"""
# Caching wrapper: convert tensor to a hashable tuple of tuples.
pos_tuple = tuple(map(tuple, positions.tolist()))
device_str = str(positions.device)
return self._forward_cached(pos_tuple, device_str)
@functools.lru_cache(maxsize=16)
def _forward_cached(
self, pos_tuple: tuple[tuple[int, ...], ...], device_str: str
) -> tuple[torch.Tensor, torch.Tensor]:
"""
The core implementation that computes embeddings from a position tensor.
This method is wrapped by an LRU cache.
"""
device = torch.device(device_str)
positions = torch.tensor(pos_tuple, dtype=torch.long, device=device)
return self.forward_uncached(pos=positions)
def forward_uncached(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
        The core implementation that computes embeddings from a position tensor.
        Unlike `_forward_cached`, this path is not wrapped by an LRU cache.
"""
device = pos.device
# Pre-allocate the final tensors for efficiency.
num_tokens = pos.shape[0]
first_generator = self.rope_generators[0]
if first_generator.use_real and first_generator.repeat_interleave_real:
head_dim = sum(self.rope_dim_list)
else:
head_dim = sum(self.rope_dim_list) // 2
cos = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype)
sin = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype)
col_offset = 0
for i in range(self.ndim):
# Extract position coordinates for the current dimension for all tokens.
pos_i = pos[:, i].to(self.dtype)
# Get the appropriate 1D generator.
gen_idx = self.dim_idx_to_gen_idx[i]
generator = self.rope_generators[gen_idx]
# Calculate 1D embeddings.
cos_1d, sin_1d = generator(pos_i)
slice_width = cos_1d.shape[1]
cos[:, col_offset : col_offset + slice_width] = cos_1d
sin[:, col_offset : col_offset + slice_width] = sin_1d
col_offset += slice_width
return cos.float(), sin.float()
def forward_from_grid(
self,
grid_size: tuple[int, ...],
shard_dim: int = 0,
start_frame: int = 0,
device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Caching wrapper: use grid parameters directly as the key.
# grid_tuple = _to_tuple(grid_size, dim=self.ndim)
device_str = str(device) if device is not None else "cpu"
return self._forward_cached_from_grid(
grid_size, shard_dim, start_frame, device_str
)
@functools.lru_cache(maxsize=16)
def _forward_cached_from_grid(
self,
grid_size: tuple[int, ...],
shard_dim: int,
start_frame: int,
device_str: str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Computes embeddings for a structured grid, using a highly efficient
implementation that avoids materializing the full position tensor.
This method is wrapped by an LRU cache.
"""
device = torch.device(device_str)
sp_group = get_sp_group()
sp_rank = sp_group.rank_in_group
sp_world_size = sp_group.world_size
sizes = _to_tuple(grid_size, dim=self.ndim)
starts = (0,) * self.ndim
# Apply sequence parallel sharding to the sizes and compute shard offset
shard_sizes = list(sizes)
shard_offsets = [0] * self.ndim
if sp_world_size > 1:
assert sizes[shard_dim] % sp_world_size == 0, (
f"Dimension {shard_dim} with size {sizes[shard_dim]} is not divisible "
f"by sequence parallel world size {sp_world_size}"
)
shard_size = sizes[shard_dim] // sp_world_size
shard_offsets[shard_dim] = sp_rank * shard_size
shard_sizes[shard_dim] = shard_size
# Pre-allocate outputs on the requested device to avoid CPU ops and extra cats
num_tokens = 1
for s in shard_sizes:
num_tokens *= int(s)
head_dim_half = sum(self.rope_dim_list) // 2
cos = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype)
sin = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype)
# Compute per-axis 1D embeddings once and expand via repeats to [N, d_i/2]
col_offset = 0
for i in range(self.ndim):
dim_i = self.rope_dim_list[i]
dim_i_half = dim_i // 2
size_i = int(shard_sizes[i])
# Starting position for this axis, with optional frame offset for time axis (i==0)
base_offset = starts[i]
if i == 0 and start_frame > 0:
base_offset += start_frame
if sp_world_size > 1 and i == shard_dim:
base_offset += shard_offsets[i]
gen_idx = self.dim_idx_to_gen_idx[i]
generator = self.rope_generators[gen_idx]
cos_1d, sin_1d = generator.forward_from_grid(
size_i, base_offset, device_str
)
# Expand to [num_tokens, dim_i/2] matching flatten order (last dims vary fastest)
repeats_per_entry = 1
for j in range(i + 1, self.ndim):
repeats_per_entry *= int(shard_sizes[j])
tile_count = 1
for j in range(0, i):
tile_count *= int(shard_sizes[j])
cos_expanded = cos_1d.repeat_interleave(repeats_per_entry, dim=0)
sin_expanded = sin_1d.repeat_interleave(repeats_per_entry, dim=0)
if tile_count > 1:
cos_expanded = cos_expanded.repeat(tile_count, 1)
sin_expanded = sin_expanded.repeat(tile_count, 1)
cos[:, col_offset : col_offset + dim_i_half] = cos_expanded
sin[:, col_offset : col_offset + dim_i_half] = sin_expanded
col_offset += dim_i_half
return cos.float(), sin.float()
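# Illustrative sketch, not part of the original file: computes cos/sin for explicit
# (t, h, w) integer coordinates via forward(). forward_from_grid additionally needs the
# sequence-parallel group to be initialized, so it is not exercised here.
def _example_nd_rope() -> None:
    nd_rope = NDRotaryEmbedding(rope_dim_list=[16, 24, 24], rope_theta=10000.0)
    # 2 frames x 2 rows x 2 cols = 8 tokens, each with a (t, h, w) coordinate.
    coords = torch.stack(
        torch.meshgrid(
            torch.arange(2), torch.arange(2), torch.arange(2), indexing="ij"
        ),
        dim=-1,
    ).reshape(-1, 3)
    cos, sin = nd_rope(coords)
    assert cos.shape == (8, sum([16, 24, 24]) // 2)  # [num_tokens, head_dim // 2]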
def _to_tuple(x: int | tuple[int, ...], dim: int = 2) -> tuple[int, ...]:
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(
start: int | tuple[int, ...],
*args: int | tuple[int, ...],
dim: int = 2,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
        grid (torch.Tensor): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = tuple(stop[i] - start[i] for i in range(dim))
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=dtype, device=device)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
def get_1d_rotary_pos_embed(
dim: int,
pos: torch.FloatTensor | int,
theta: float = 10000.0,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
dtype: torch.dtype = torch.float32,
device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
    This function calculates a frequency tensor with complex exponentials for the given dimension 'dim'
    and position indices 'pos'. The 'theta' parameter scales the frequencies.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
interpolation_factor (float, optional): Factor to scale positions. Defaults to 1.0.
Returns:
        freqs_cos, freqs_sin: Precomputed cos and sin frequency tensors, each of shape [S, D/2].
"""
if isinstance(pos, int):
pos = torch.arange(pos, dtype=dtype, device=device)
elif (
isinstance(pos, torch.Tensor)
and device is not None
and pos.device != torch.device(device)
):
# Ensure positions are on the requested device to avoid implicit CPU ops.
pos = pos.to(device)
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (
theta
** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].to(dtype) / dim).to(
device=device
)
) # [D/2]
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
freqs_cos = freqs.cos() # [S, D/2]
freqs_sin = freqs.sin() # [S, D/2]
return freqs_cos, freqs_sin
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
theta_rescale_factor: float | list[float] = 1.0,
interpolation_factor: float | list[float] = 1.0,
shard_dim: int = 0,
sp_rank: int = 0,
sp_world_size: int = 1,
dtype: torch.dtype = torch.float32,
start_frame: int = 0,
device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
    This is an n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Supports sequence parallelism by allowing sharding of a specific dimension.
Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
interpolation_factor (float): Factor to scale positions. Defaults to 1.0.
shard_dim (int): Which dimension to shard for sequence parallelism. Defaults to 0.
sp_rank (int): Rank in the sequence parallel group. Defaults to 0.
sp_world_size (int): World size of the sequence parallel group. Defaults to 1.
Returns:
        Tuple[torch.Tensor, torch.Tensor]: (cos, sin) tensors of shape [num_tokens, D/2]
"""
# Determine per-axis sizes for the (possibly sharded) grid without materializing it
ndim = len(rope_dim_list)
if len(args) == 0:
# start is grid_size
sizes = _to_tuple(start, dim=ndim)
starts = (0,) * ndim
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
starts = _to_tuple(start, dim=ndim)
stops = _to_tuple(args[0], dim=ndim)
sizes = tuple(stops[i] - starts[i] for i in range(ndim))
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
starts = _to_tuple(start, dim=ndim)
_ = _to_tuple(args[0], dim=ndim) # stop, unused here
sizes = _to_tuple(args[1], dim=ndim)
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
assert (
shard_dim < ndim
), f"shard_dim {shard_dim} must be less than number of dimensions {ndim}"
# Apply sequence parallel sharding to the sizes and compute shard offset
shard_sizes = list(sizes)
shard_offsets = [0] * ndim
if sp_world_size > 1:
assert sizes[shard_dim] % sp_world_size == 0, (
f"Dimension {shard_dim} with size {sizes[shard_dim]} is not divisible "
f"by sequence parallel world size {sp_world_size}"
)
shard_size = sizes[shard_dim] // sp_world_size
shard_offsets[shard_dim] = sp_rank * shard_size
shard_sizes[shard_dim] = shard_size
# Handle theta scaling/interpolation factor per-axis
if isinstance(theta_rescale_factor, int | float):
theta_rescale_factor = [theta_rescale_factor] * ndim
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * ndim
assert (
len(theta_rescale_factor) == ndim
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int | float):
interpolation_factor = [interpolation_factor] * ndim
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * ndim
assert (
len(interpolation_factor) == ndim
), "len(interpolation_factor) should equal to len(rope_dim_list)"
# Pre-allocate outputs on the requested device to avoid CPU ops and extra cats
num_tokens = 1
for s in shard_sizes:
num_tokens *= int(s)
head_dim_half = sum(rope_dim_list) // 2
cos = torch.empty((num_tokens, head_dim_half), device=device, dtype=dtype)
sin = torch.empty((num_tokens, head_dim_half), device=device, dtype=dtype)
# Compute per-axis 1D embeddings once and expand via repeats to [N, d_i/2]
col_offset = 0
for i in range(ndim):
dim_i = int(rope_dim_list[i])
dim_i_half = dim_i // 2
size_i = int(shard_sizes[i])
# Starting position for this axis, with optional frame offset for time axis (i==0)
base_offset = starts[i]
if i == 0 and start_frame > 0:
base_offset += start_frame
if sp_world_size > 1 and i == shard_dim:
base_offset += shard_offsets[i]
pos_i = torch.arange(size_i, device=device, dtype=dtype) + base_offset
cos_1d, sin_1d = get_1d_rotary_pos_embed(
dim_i,
pos_i,
theta=theta,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
dtype=dtype,
device=device,
) # [size_i, dim_i/2]
# Expand to [num_tokens, dim_i/2] matching flatten order (last dims vary fastest)
repeats_per_entry = 1
for j in range(i + 1, ndim):
repeats_per_entry *= int(shard_sizes[j])
tile_count = 1
for j in range(0, i):
tile_count *= int(shard_sizes[j])
cos_expanded = cos_1d.repeat_interleave(repeats_per_entry, dim=0)
sin_expanded = sin_1d.repeat_interleave(repeats_per_entry, dim=0)
if tile_count > 1:
cos_expanded = cos_expanded.repeat(tile_count, 1)
sin_expanded = sin_expanded.repeat(tile_count, 1)
cos[:, col_offset : col_offset + dim_i_half] = cos_expanded
sin[:, col_offset : col_offset + dim_i_half] = sin_expanded
col_offset += dim_i_half
return cos, sin
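# Illustrative sketch, not part of the original file: RoPE tables for a small (t, h, w)
# grid with no sequence parallelism. The grid and per-axis dims are assumptions.
def _example_get_nd_rotary_pos_embed() -> None:
    cos, sin = get_nd_rotary_pos_embed(
        rope_dim_list=[16, 24, 24],  # sums to the attention head_dim (64 here)
        start=(2, 4, 4),  # grid_size: 2 frames x 4 x 4 latents
        theta=10000.0,
    )
    assert cos.shape == (2 * 4 * 4, 64 // 2)  # [num_tokens, head_dim // 2]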
def get_rotary_pos_embed(
rope_sizes,
hidden_size,
heads_num,
rope_dim_list,
rope_theta,
theta_rescale_factor=1.0,
interpolation_factor=1.0,
shard_dim: int = 0,
dtype: torch.dtype = torch.float32,
start_frame: int = 0,
device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Generate rotary positional embeddings for the given sizes.
Args:
rope_sizes: Tuple of dimensions (t, h, w)
hidden_size: Hidden dimension size
heads_num: Number of attention heads
rope_dim_list: List of dimensions for each axis, or None
rope_theta: Base for frequency calculations
theta_rescale_factor: Rescale factor for theta. Defaults to 1.0
interpolation_factor: Factor to scale positions. Defaults to 1.0
shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0.
Returns:
Tuple of (cos, sin) tensors for rotary embeddings
"""
target_ndim = 3
head_dim = hidden_size // heads_num
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
# Get SP info - now handled within NDRotaryEmbedding
# sp_group = get_sp_group()
# sp_rank = sp_group.rank_in_group
# sp_world_size = sp_group.world_size
# Simple LRU cache keyed by parameters
global _ND_ROPE_CACHE
key = (
tuple(rope_dim_list),
float(rope_theta),
(
tuple(theta_rescale_factor)
if isinstance(theta_rescale_factor, list)
else float(theta_rescale_factor)
),
(
tuple(interpolation_factor)
if isinstance(interpolation_factor, list)
else float(interpolation_factor)
),
dtype,
)
cache_hit = key in _ND_ROPE_CACHE
if cache_hit:
rope_emb = _ND_ROPE_CACHE.pop(key)
_ND_ROPE_CACHE[key] = rope_emb # move to end (most-recent)
else:
rope_emb = NDRotaryEmbedding(
rope_dim_list=rope_dim_list,
rope_theta=rope_theta,
theta_rescale_factor=theta_rescale_factor,
interpolation_factor=interpolation_factor,
dtype=dtype,
)
_ND_ROPE_CACHE[key] = rope_emb
if len(_ND_ROPE_CACHE) > 16:
# pop least-recently-used
_ND_ROPE_CACHE.pop(next(iter(_ND_ROPE_CACHE)))
freqs_cos, freqs_sin = rope_emb.forward_from_grid(
grid_size=_to_tuple(rope_sizes, dim=3),
shard_dim=shard_dim,
start_frame=start_frame,
device=device,
)
return freqs_cos, freqs_sin
_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
_ND_ROPE_CACHE: "OrderedDict[tuple, NDRotaryEmbedding]" = OrderedDict()
_ROPE_3D_CACHE: "OrderedDict[tuple, tuple[torch.Tensor, torch.Tensor]]" = OrderedDict()
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: int | float,
is_neox_style: bool = True,
rope_scaling: dict[str, Any] | None = None,
dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling_args,
dtype,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
raise ValueError(f"Unknown RoPE scaling {rope_scaling}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
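# Illustrative sketch, not part of the original file: builds a cached RotaryEmbedding via
# get_rope and applies it through the PyTorch-native path. Sizes are assumptions.
def _example_get_rope() -> None:
    rope = get_rope(head_size=32, rotary_dim=32, max_position=256, base=10000.0)
    # A second call with identical arguments returns the cached instance.
    assert get_rope(head_size=32, rotary_dim=32, max_position=256, base=10000.0) is rope
    positions = torch.arange(4)
    query = torch.randn(4, 2 * 32)
    key = torch.randn(4, 2 * 32)
    query, key = rope.forward_native(positions, query, key)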
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# TODO: for temporary usage, expecting a refactor
from typing import Optional
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from torch import Tensor
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_warps=8),
],
key=["inner_dim"],
)
@triton.jit
def _fused_scale_shift_4d_kernel(
output_ptr,
normalized_ptr,
scale_ptr,
shift_ptr,
rows,
inner_dim,
seq_len,
num_frames,
frame_seqlen,
BLOCK_N: tl.constexpr,
):
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)
col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
mask = col_offsets < inner_dim
# Pointers for normalized and output
row_base = pid_row * inner_dim
norm_ptrs = normalized_ptr + row_base + col_offsets
out_ptrs = output_ptr + row_base + col_offsets
# Pointers for scale and shift for 4D
b_idx = pid_row // seq_len
t_idx = pid_row % seq_len
frame_idx_in_batch = t_idx // frame_seqlen
scale_row_idx = b_idx * num_frames + frame_idx_in_batch
scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets
shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets
normalized = tl.load(norm_ptrs, mask=mask, other=0.0)
scale = tl.load(scale_ptrs, mask=mask, other=0.0)
shift = tl.load(shift_ptrs, mask=mask, other=0.0)
one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype)
output = normalized * (one + scale) + shift
tl.store(out_ptrs, output, mask=mask)
@triton.jit
def fuse_scale_shift_kernel_blc_opt(
x_ptr,
shift_ptr,
scale_ptr,
y_ptr,
B,
L,
C,
stride_x_b,
stride_x_l,
stride_x_c,
stride_s_b,
stride_s_l,
stride_s_c,
stride_sc_b,
stride_sc_l,
stride_sc_c,
SCALE_IS_SCALAR: tl.constexpr,
SHIFT_IS_SCALAR: tl.constexpr,
BLOCK_L: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_l = tl.program_id(0)
pid_c = tl.program_id(1)
pid_b = tl.program_id(2)
l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
mask_l = l_offsets < L
mask_c = c_offsets < C
mask = mask_l[:, None] & mask_c[None, :]
x_off = (
pid_b * stride_x_b
+ l_offsets[:, None] * stride_x_l
+ c_offsets[None, :] * stride_x_c
)
x = tl.load(x_ptr + x_off, mask=mask, other=0)
if SHIFT_IS_SCALAR:
shift_val = tl.load(shift_ptr)
shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)
else:
s_off = (
pid_b * stride_s_b
+ l_offsets[:, None] * stride_s_l
+ c_offsets[None, :] * stride_s_c
)
shift = tl.load(shift_ptr + s_off, mask=mask, other=0)
if SCALE_IS_SCALAR:
scale_val = tl.load(scale_ptr)
scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)
else:
sc_off = (
pid_b * stride_sc_b
+ l_offsets[:, None] * stride_sc_l
+ c_offsets[None, :] * stride_sc_c
)
scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)
y = x * (1 + scale) + shift
tl.store(y_ptr + x_off, y, mask=mask)
def fuse_scale_shift_kernel(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
block_l: int = 128,
block_c: int = 128,
):
assert x.is_cuda and scale.is_cuda
assert x.is_contiguous()
B, L, C = x.shape
output = torch.empty_like(x)
if scale.dim() == 4:
# scale/shift: [B, F, 1, C]
rows = B * L
x_2d = x.view(rows, C)
output_2d = output.view(rows, C)
grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"]))
num_frames = scale.shape[1]
assert (
L % num_frames == 0
), "seq_len must be divisible by num_frames for 4D scale/shift"
frame_seqlen = L // num_frames
# Compact [B, F, C] without the singleton dim into [B*F, C]
scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous()
_fused_scale_shift_4d_kernel[grid](
output_2d,
x_2d,
scale_reshaped,
shift_reshaped,
rows,
C,
L,
num_frames,
frame_seqlen,
)
else:
# 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
# 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
# Also support scalar (0D or 1-element)
if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
scale_blc = scale.reshape(1)
elif scale.dim() == 2:
scale_blc = scale[:, None, :]
elif scale.dim() == 3:
scale_blc = scale
else:
raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")
if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
shift_blc = shift.reshape(1)
elif shift.dim() == 2:
shift_blc = shift[:, None, :]
elif shift.dim() == 3:
shift_blc = shift
else:
# broadcast later via expand if possible
shift_blc = shift
need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1
if not need_scale_scalar:
scale_exp = scale_blc.expand(B, L, C)
s_sb, s_sl, s_sc = scale_exp.stride()
else:
s_sb = s_sl = s_sc = 0
if not need_shift_scalar:
shift_exp = shift_blc.expand(B, L, C)
sh_sb, sh_sl, sh_sc = shift_exp.stride()
else:
sh_sb = sh_sl = sh_sc = 0
# If both scalars and both zero, copy fast-path
if need_scale_scalar and need_shift_scalar:
if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0):
output.copy_(x)
return output
grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
fuse_scale_shift_kernel_blc_opt[grid](
x,
shift_blc if need_shift_scalar else shift_exp,
scale_blc if need_scale_scalar else scale_exp,
output,
B,
L,
C,
x.stride(0),
x.stride(1),
x.stride(2),
sh_sb,
sh_sl,
sh_sc,
s_sb,
s_sl,
s_sc,
SCALE_IS_SCALAR=need_scale_scalar,
SHIFT_IS_SCALAR=need_shift_scalar,
BLOCK_L=block_l,
BLOCK_C=block_c,
num_warps=4,
num_stages=2,
)
return output
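# Illustrative sketch, not part of the original file: adaLN-style modulation
# y = x * (1 + scale) + shift checked against the eager computation. Requires a CUDA
# device since the kernels above are Triton kernels; sizes and dtype are assumptions.
def _example_fuse_scale_shift() -> None:
    if not torch.cuda.is_available():
        return
    x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
    scale = torch.randn(2, 64, device="cuda", dtype=torch.float16)  # broadcast over L
    shift = torch.randn(2, 64, device="cuda", dtype=torch.float16)
    y = fuse_scale_shift_kernel(x, scale, shift)
    ref = x * (1 + scale[:, None, :]) + shift[:, None, :]
    torch.testing.assert_close(y, ref, rtol=1e-2, atol=1e-2)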
@triton.autotune(
configs=[
triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
],
key=["head_size", "interleaved"],
)
@triton.jit
def _rotary_embedding_kernel(
output_ptr,
x_ptr,
cos_ptr,
sin_ptr,
num_heads,
head_size,
num_tokens,
stride_x_row,
stride_cos_row,
stride_sin_row,
interleaved: tl.constexpr,
BLOCK_HS_HALF: tl.constexpr,
):
row_idx = tl.program_id(0)
token_idx = (row_idx // num_heads) % num_tokens
x_row_ptr = x_ptr + row_idx * stride_x_row
cos_row_ptr = cos_ptr + token_idx * stride_cos_row
sin_row_ptr = sin_ptr + token_idx * stride_sin_row
output_row_ptr = output_ptr + row_idx * stride_x_row
# half size for x1 and x2
head_size_half = head_size // 2
for block_start in range(0, head_size_half, BLOCK_HS_HALF):
offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)
mask = offsets_half < head_size_half
cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)
sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)
offsets_x1 = 2 * offsets_half
offsets_x2 = 2 * offsets_half + 1
x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)
x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)
x1_fp32 = x1_vals.to(tl.float32)
x2_fp32 = x2_vals.to(tl.float32)
cos_fp32 = cos_vals.to(tl.float32)
sin_fp32 = sin_vals.to(tl.float32)
o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)
tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)
tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)
def apply_rotary_embedding(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
output = torch.empty_like(x)
if x.dim() > 3:
bsz, num_tokens, num_heads, head_size = x.shape
else:
num_tokens, num_heads, head_size = x.shape
bsz = 1
assert head_size % 2 == 0, "head_size must be divisible by 2"
x_reshaped = x.view(-1, head_size)
output_reshaped = output.view(-1, head_size)
# num_tokens per head, 1 token per block
grid = (bsz * num_tokens * num_heads,)
if interleaved and cos.shape[-1] == head_size:
cos = cos[..., ::2].contiguous()
sin = sin[..., ::2].contiguous()
else:
cos = cos.contiguous()
sin = sin.contiguous()
_rotary_embedding_kernel[grid](
output_reshaped,
x_reshaped,
cos,
sin,
num_heads,
head_size,
num_tokens,
x_reshaped.stride(0),
cos.stride(0),
sin.stride(0),
interleaved,
)
return output
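# Illustrative sketch, not part of the original file: checks the Triton rotary kernel
# against the interleaved (GPT-J-style) reference math. Requires a CUDA device.
def _example_triton_rotary() -> None:
    if not torch.cuda.is_available():
        return
    num_tokens, num_heads, head_size = 8, 4, 64
    x = torch.randn(num_tokens, num_heads, head_size, device="cuda")
    cos = torch.randn(num_tokens, head_size // 2, device="cuda")
    sin = torch.randn(num_tokens, head_size // 2, device="cuda")
    out = apply_rotary_embedding(x, cos, sin)
    x1, x2 = x[..., ::2], x[..., 1::2]
    c, s = cos.unsqueeze(1), sin.unsqueeze(1)
    ref = torch.stack((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1).flatten(-2)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)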
# RMSNorm-fp32
def maybe_contiguous_lastdim(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def maybe_contiguous(x):
return x.contiguous() if x is not None else None
def triton_autotune_configs():
# Return configs with a valid warp count for the current device
    # Maximum threads per block is architecture-dependent in theory, but in practice it is 1024
    max_threads_per_block = 1024
# Default to warp size 32 if not defined by device
warp_size = getattr(
torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32
)
# Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
return [
triton.Config({}, num_warps=warp_count)
for warp_count in [1, 2, 4, 8, 16, 32]
if warp_count * warp_size <= max_threads_per_block
]
# return [triton.Config({}, num_warps=8)]
# Copied from flash-attn
@triton.autotune(
configs=triton_autotune_configs(),
key=[
"N",
"HAS_RESIDUAL",
"STORE_RESIDUAL_OUT",
"IS_RMS_NORM",
"HAS_BIAS",
"HAS_WEIGHT",
"HAS_X1",
"HAS_W1",
"HAS_B1",
],
)
# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
X1,
W1,
B1,
Y1,
RESIDUAL_OUT, # pointer to the residual
ROWSCALE,
SEEDS, # Dropout seeds for each row
DROPOUT_MASK,
DROPOUT_MASK1,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
stride_x1_row,
stride_y1_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
dropout_p, # Dropout probability
zero_centered_weight, # If true, add 1.0 to the weight
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_X1: tl.constexpr,
HAS_W1: tl.constexpr,
HAS_B1: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
if HAS_X1:
X1 += row * stride_x1_row
if HAS_W1:
Y1 += row * stride_y1_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
x *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = (
tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
)
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
if HAS_X1:
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
x1 *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = (
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
> dropout_p
)
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
x += x1
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w += 1.0
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
y = x_hat * w + b if HAS_BIAS else x_hat * w
else:
y = x_hat + b if HAS_BIAS else x_hat
# Write output
tl.store(Y + cols, y, mask=mask)
if HAS_W1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w1 += 1.0
if HAS_B1:
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
x: Tensor,
weight: Tensor,
bias: Tensor,
eps: float,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
residual_dtype: Optional[torch.dtype] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
out: Optional[Tensor] = None,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
    # Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
# and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
# so that _layer_norm_fwd_impl doesn't have to return them.
if out is None:
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
if residual is not None:
residual_dtype = residual.dtype
if residual_out is None and (
residual is not None
or (residual_dtype is not None and residual_dtype != x.dtype)
or dropout_p > 0.0
or rowscale is not None
or x1 is not None
):
residual_out = torch.empty_like(
x, dtype=residual_dtype if residual_dtype is not None else x.dtype
)
else:
residual_out = None
y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
x,
weight,
bias,
eps,
out,
residual=residual,
x1=x1,
weight1=weight1,
bias1=bias1,
dropout_p=dropout_p,
rowscale=rowscale,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
residual_out=residual_out,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
if residual_out is None:
residual_out = x
return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
# since we're returning a tuple of tensors
def _layer_norm_fwd_impl(
x: Tensor,
weight: Optional[Tensor],
bias: Tensor,
eps: float,
out: Tensor,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if x1 is not None:
assert x1.shape == x.shape
assert rowscale is None
assert x1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
assert out.shape == x.shape
assert out.stride(-1) == 1
if residual_out is not None:
assert residual_out.shape == x.shape
assert residual_out.stride(-1) == 1
if weight1 is not None:
y1 = torch.empty_like(out)
assert y1.stride(-1) == 1
else:
y1 = None
mean = (
torch.empty((M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
if dropout_p > 0.0:
seeds = torch.randint(
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
)
else:
seeds = None
if return_dropout_mask and dropout_p > 0.0:
dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
if x1 is not None:
dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
else:
dropout_mask1 = None
else:
dropout_mask, dropout_mask1 = None, None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
with torch.cuda.device(x.device.index):
torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
x,
out,
weight if weight is not None else x, # unused when HAS_WEIGHT == False
bias,
residual,
x1,
weight1,
bias1,
y1,
residual_out,
rowscale,
seeds,
dropout_mask,
dropout_mask1,
mean,
rstd,
x.stride(0),
out.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
x1.stride(0) if x1 is not None else 0,
y1.stride(0) if y1 is not None else 0,
M,
N,
eps,
dropout_p,
            # Passing a bool makes torch inductor very unhappy since it then tries to compare to int_max
int(zero_centered_weight),
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
dropout_p > 0.0,
dropout_mask is not None,
rowscale is not None,
HAS_X1=x1 is not None,
HAS_W1=weight1 is not None,
HAS_B1=bias1 is not None,
)
return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
class LayerNormFn:
@staticmethod
def forward(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
if residual is not None:
assert residual.shape == x_shape_og
residual = maybe_contiguous_lastdim(
residual.reshape(-1, residual.shape[-1])
)
if x1 is not None:
assert x1.shape == x_shape_og
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
# weight can be None when elementwise_affine=False for LayerNorm
if weight is not None:
weight = weight.contiguous()
bias = maybe_contiguous(bias)
weight1 = maybe_contiguous(weight1)
bias1 = maybe_contiguous(bias1)
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
if out is not None:
out = out.reshape(-1, out.shape[-1])
if residual_out is not None:
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
_layer_norm_fwd(
x,
weight,
bias,
eps,
residual,
x1,
weight1,
bias1,
dropout_p=dropout_p,
rowscale=rowscale,
out_dtype=out_dtype,
residual_dtype=residual_dtype,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
out=out,
residual_out=residual_out,
)
)
y = y.reshape(x_shape_og)
return y
def layer_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
is_rms_norm,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
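# Illustrative usage sketch (not part of the original file): the fused Triton kernel
# requires a CUDA device and a contiguous last dimension, so this assumes a CUDA
# build of torch/triton is available.
#
#   x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
#   w = torch.ones(1024, device="cuda", dtype=torch.bfloat16)
#   b = torch.zeros(1024, device="cuda", dtype=torch.bfloat16)
#   y = layer_norm_fn(x, w, b, eps=1e-6)                   # standard LayerNorm
#   y_rms = layer_norm_fn(x, w, None, is_rms_norm=True)    # RMSNorm variant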
@triton.jit
def _norm_infer_kernel(
X,
Y,
W,
B,
stride_x_row,
stride_y_row,
M,
N,
eps,
IS_RMS_NORM: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
BLOCK_N: tl.constexpr,
):
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_WEIGHT:
W += 0
if HAS_BIAS:
B += 0
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
y = x_hat * w
else:
y = x_hat
if HAS_BIAS:
b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
y += b
tl.store(Y + cols, y, mask=cols < N)
def norm_infer(
x: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
is_rms_norm: bool = False,
out: Optional[Tensor] = None,
):
M, N = x.shape
assert x.stride(-1) == 1
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.shape == (N,)
assert bias.stride(-1) == 1
if out is None:
out = torch.empty_like(x)
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_N // 256, 1), 8)
_norm_infer_kernel[(M,)](
x,
out,
weight if weight is not None else x, # dummy when HAS_WEIGHT=False
bias if bias is not None else x, # dummy when HAS_BIAS=False
x.stride(0),
out.stride(0),
M,
N,
eps,
IS_RMS_NORM=is_rms_norm,
HAS_WEIGHT=weight is not None,
HAS_BIAS=bias is not None,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
)
return out
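# Minimal usage sketch for norm_infer (illustrative; assumes CUDA + triton). Unlike
# layer_norm_fn above, this path has no residual/dropout handling and is meant for
# plain inference-time normalization:
#
#   x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
#   w = torch.ones(4096, device="cuda", dtype=torch.float16)
#   y = norm_infer(x, w, bias=None, eps=1e-6, is_rms_norm=True)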
def rms_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
True,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import logging
from typing import TYPE_CHECKING
import torch
import torch.distributed._functional_collectives as ft_c
from packaging.version import parse
from torch.distributed.tensor.experimental._attention import _cp_options
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_sp_group,
get_ulysses_parallel_world_size,
)
_cp_options.enable_load_balance = False
if TYPE_CHECKING:
from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import (
AttentionImpl,
)
logger = logging.getLogger(__name__)
def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
"""
When tracing the code, the result tensor is not an AsyncCollectiveTensor,
so we cannot call ``wait()``.
"""
if isinstance(tensor, ft_c.AsyncCollectiveTensor):
return tensor.wait()
return tensor
def _usp_all_to_all_single(x: torch.Tensor) -> torch.Tensor:
ulysses_pg = get_sp_group().ulysses_group
assert ulysses_pg is not None, "Ulysses process group is not initialized."
x_shape = x.shape
x = x.flatten()
x = ft_c.all_to_all_single(
x, output_split_sizes=None, input_split_sizes=None, group=ulysses_pg
)
x = _maybe_wait(x)
x = x.reshape(x_shape)
return x
def _usp_input_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor:
"""
Perform Ulysses-style input all-to-all over the head dimension.
Default layout expects heads at dim=1 and sequence at dim=2:
[b, h, s_local, d] -> [b, h // world_size, s_global, d]
If heads are at dim=2 (input is [b, s_local, h, d]), set head_dim=2, and the
function returns [b, s_global, h // world_size, d], preserving the original
head/sequence dim ordering.
Args:
x: A 4D tensor with layout [b, *, *, d] where '*' are sequence and heads
head_dim: Which dimension index corresponds to heads (1 or 2)
Returns:
Tensor with the same dim order as input, with heads sharded and sequence gathered.
"""
world_size = get_ulysses_parallel_world_size()
if world_size <= 1:
return x
assert x.ndim == 4, f"x must have 4 dimensions, got {x.ndim}"
assert head_dim in (1, 2), f"head_dim must be 1 or 2, got {head_dim}"
seq_dim = 1 if head_dim == 2 else 2
# Bring to canonical [b, h, s, d]
if head_dim == 1 and seq_dim == 2:
x_c = x
else:
x_c = x.permute(0, head_dim, seq_dim, 3).contiguous()
b, h, s, d = x_c.shape
assert (
h % world_size == 0
), f"h ({h}) must be divisible by world_size ({world_size})"
# [b, h, s, d] -> [h, b, s, d]
x_c = x_c.permute(1, 0, 2, 3).contiguous()
# all-to-all along h
x_c = _usp_all_to_all_single(x_c)
# -> [b, h // world, s * world, d]
x_c = (
x_c.reshape(world_size, h // world_size, b, -1, d)
.permute(2, 1, 0, 3, 4)
.reshape(b, h // world_size, -1, d)
)
if head_dim == 1 and seq_dim == 2:
return x_c
# Map back to original ordering, preserving head/seq positions
new_order = [0, None, None, 3]
new_order[head_dim] = 1
new_order[seq_dim] = 2
return x_c.permute(tuple(new_order)).contiguous()
def _usp_output_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor:
"""
Perform Ulysses-style output all-to-all over the head dimension (inverse of input).
Default layout expects heads at dim=1 and sequence at dim=2:
[b, h // world_size, s_global, d] -> [b, h, s_local, d]
If heads are at dim=2 (input is [b, s_global, h // world_size, d]), set head_dim=2,
and the function returns [b, s_local, h, d], preserving the original head/sequence
dim ordering.
Args:
x: A 4D tensor with layout [b, *, *, d] where '*' are sequence and heads
head_dim: Which dimension index corresponds to heads (1 or 2)
Returns:
Tensor with the same dim order as input, with heads gathered and sequence sharded.
"""
world_size = get_ulysses_parallel_world_size()
if world_size <= 1:
return x
assert x.ndim == 4, f"x must have 4 dimensions, got {x.ndim}"
assert head_dim in (1, 2), f"head_dim must be 1 or 2, got {head_dim}"
seq_dim = 1 if head_dim == 2 else 2
# Bring to canonical [b, h, s, d]
if head_dim == 1 and seq_dim == 2:
x_c = x
else:
x_c = x.permute(0, head_dim, seq_dim, 3).contiguous()
b, h, s, d = x_c.shape
assert (
s % world_size == 0
), f"s ({s}) must be divisible by world_size ({world_size})"
# [b, h, s, d] -> [s, b, h, d]
x_c = x_c.permute(2, 0, 1, 3).contiguous()
x_c = _usp_all_to_all_single(x_c)
# -> [b, h * world, s // world, d]
x_c = (
x_c.reshape(world_size, s // world_size, b, -1, d)
.permute(2, 0, 3, 1, 4)
.reshape(b, -1, s // world_size, d)
)
if head_dim == 1 and seq_dim == 2:
return x_c
# Map back to original ordering, preserving head/seq positions
new_order = [0, None, None, 3]
new_order[head_dim] = 1
new_order[seq_dim] = 2
return x_c.permute(tuple(new_order)).contiguous()
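# Hypothetical round-trip sketch (not part of the original API): with an initialized
# Ulysses sequence-parallel group, the input all-to-all shards heads and gathers the
# sequence, and the output all-to-all inverts it. The helper below only documents the
# expected shapes; it assumes torch.distributed and the SP group are already set up.
def _usp_round_trip_example(q_local: torch.Tensor) -> torch.Tensor:
    # q_local: [b, h, s_local, d] with heads at dim=1.
    q_global = _usp_input_all_to_all(q_local, head_dim=1)  # -> [b, h // ws, s_global, d]
    # ... attention over the full sequence with the local head shard would run here ...
    return _usp_output_all_to_all(q_global, head_dim=1)    # -> [b, h, s_local, d]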
def ring_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_impl: "AttentionImpl",
is_causal: bool = False,
dropout_p: float = 0.0,
):
"""
Ring Attention implementation.
This function implements Ring Attention, a strategy for distributed attention
computation that reduces peak memory usage. It accepts a generic attention
implementation (`attn_impl`) which is called by the underlying PyTorch
distributed attention primitive.
Args:
query, key, value: The input tensors for attention.
attn_impl: An instance of an attention implementation backend
(e.g., FlashAttentionImpl) whose `forward` method will be
used as the computational kernel.
is_causal: Whether to apply causal masking.
dropout_p: Dropout probability.
"""
    # torch.distributed.tensor.experimental._attention is not a public API, so it is
    # imported lazily here and may change between torch releases.
from torch.distributed.tensor.experimental._attention import (
_templated_ring_attention,
)
ring_pg = get_sp_group().ring_group
assert ring_pg is not None, "Ring process group is not initialized."
# Ring attention primitives expect tensors in [B, H, S, D] layout.
# We permute the inputs here.
query = torch.permute(query, [0, 2, 1, 3]).contiguous()
key = torch.permute(key, [0, 2, 1, 3]).contiguous()
value = torch.permute(value, [0, 2, 1, 3]).contiguous()
# Create an adapter function that matches the signature expected by
# _templated_ring_attention. The `attn_impl` already has dropout and
# causal settings configured during its initialization.
    # Note: the attention backend and the ring attention primitive may expect
    # different QKV layouts; FlashAttention, for example, expects BSHD.
def attn_callable_adapter(q, k, v, *args, **kwargs):
# We ignore the dropout_p and is_causal passed by _templated_ring_attention
# and rely on the pre-configured attn_impl.
# The `attn_metadata` is not available here, so we pass None.
# This is a limitation we must accept when using this experimental API.
q = torch.permute(q, [0, 2, 1, 3])
k = torch.permute(k, [0, 2, 1, 3])
v = torch.permute(v, [0, 2, 1, 3])
# logger.warning(f"Warning: return_s·oftmax_lse is only supported for FlashAttentionImpl")
output, softmax_lse, *rest = attn_impl.forward(
q,
k,
v,
attn_metadata=None,
return_softmax_lse=True,
)
output = torch.permute(output, [0, 2, 1, 3])
return output, softmax_lse, *rest
    # Starting from torch 2.6.0, _templated_ring_attention expects an explicit
    # seq_dim argument for the attention function.
    needs_explicit_seq_dim = parse(torch.__version__).release >= parse("2.6.0").release
attn_kwargs = dict(
mesh=ring_pg,
op=attn_callable_adapter,
dropout_p=dropout_p,
is_causal=is_causal,
query=query,
key=key,
value=value,
)
    if needs_explicit_seq_dim:
        # For torch >= 2.6, an explicit seq_dim argument is required.
        out, *_ = _templated_ring_attention(
            seq_dim=1,
**attn_kwargs,
)
else:
out, *_ = _templated_ring_attention(
**attn_kwargs,
)
# Permute the output back to [B, S, H, D] layout.
output = torch.permute(out, [0, 2, 1, 3])
return output
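# Shape note (illustrative, not part of the original file): ring_attn expects q/k/v in
# [B, S_local, H, D] layout, permutes them to [B, H, S, D] for the ring primitive, and
# returns the output in [B, S_local, H, D]. The supplied attn_impl must support
# return_softmax_lse=True, since the log-sum-exp is needed to combine partial results.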
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/utils.py
"""Utility methods for model layers."""
import torch
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros(
(num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]
mask = bin_counts > 0
return bin_counts, mask
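# Worked example (illustrative): with vocab_size=5 and two sequences whose token ids
# use index 5 (== vocab_size) as padding, the padding counts land in the extra bin
# and are dropped by the final slice:
#
#   tokens = torch.tensor([[1, 1, 3, 5], [0, 2, 2, 5]])
#   counts, mask = get_token_bin_counts_and_mask(tokens, vocab_size=5, num_seqs=2)
#   # counts == [[0, 2, 0, 1, 0], [1, 0, 2, 0, 0]]; mask == (counts > 0)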
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import math
import torch
import torch.nn as nn
from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
Remove the _assert function in forward function to be compatible with multi-resolution images.
"""
def __init__(
self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
prefix: str = "",
):
super().__init__()
# Convert patch_size to 2-tuple
if isinstance(patch_size, list | tuple):
if len(patch_size) == 1:
patch_size = (patch_size[0], patch_size[0])
else:
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv3d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
dtype=dtype,
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # [B, C, T', H', W'] -> [B, N, C]
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(
self,
hidden_size,
act_layer="silu",
frequency_embedding_size=256,
max_period=10000,
dtype=None,
freq_dtype=torch.float32,
prefix: str = "",
):
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
self.mlp = MLP(
frequency_embedding_size,
hidden_size,
hidden_size,
act_type=act_layer,
dtype=dtype,
)
self.freq_dtype = freq_dtype
def forward(
self, t: torch.Tensor, timestep_seq_len: int | None = None
) -> torch.Tensor:
t_freq = timestep_embedding(
t, self.frequency_embedding_size, self.max_period, dtype=self.freq_dtype
).to(self.mlp.fc_in.weight.dtype)
if timestep_seq_len is not None:
t_freq = t_freq.unflatten(0, (1, timestep_seq_len))
# t_freq = t_freq.to(self.mlp.fc_in.weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
def timestep_embedding(
t: torch.Tensor,
dim: int,
max_period: int = 10000,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Create sinusoidal timestep embeddings.
Args:
t: Tensor of shape [B] with timesteps
dim: Embedding dimension
max_period: Controls the minimum frequency of the embeddings
Returns:
Tensor of shape [B, dim] with embeddings
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=dtype, device=t.device)
/ half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
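# Illustrative shape check (not part of the original file): for three timesteps and
# dim=256, the embedding concatenates 128 cosine and 128 sine features per timestep.
#
#   t = torch.tensor([0.0, 10.0, 999.0])
#   emb = timestep_embedding(t, dim=256)  # -> shape [3, 256]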
class ModulateProjection(nn.Module):
"""Modulation layer for DiT blocks."""
def __init__(
self,
hidden_size: int,
factor: int = 2,
act_layer: str = "silu",
dtype: torch.dtype | None = None,
prefix: str = "",
):
super().__init__()
self.factor = factor
self.hidden_size = hidden_size
self.linear = ReplicatedLinear(
hidden_size, hidden_size * factor, bias=True, params_dtype=dtype
)
self.act = get_act_fn(act_layer)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.act(x)
x, _ = self.linear(x)
return x
def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor:
"""
Convert patched representation back to image space.
Args:
x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w]
t, h, w: Temporal and spatial dimensions
Returns:
Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w]
"""
assert x.ndim == 3, f"x.ndim: {x.ndim}"
assert len(patch_size) == 3, f"patch_size: {patch_size}"
assert t * h * w == x.shape[1], f"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}"
c = channels
pt, ph, pw = patch_size
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
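# Illustrative shape round trip (not part of the original file): with patch_size
# (1, 2, 2), channels=16, and a latent grid of t=4, h=8, w=8, the patched input holds
# 4*8*8 = 256 tokens of width 16*1*2*2 = 64 and unpatchifies back to a video latent:
#
#   x = torch.randn(2, 256, 64)
#   video = unpatchify(x, t=4, h=8, w=8, patch_size=(1, 2, 2), channels=16)
#   # video.shape == (2, 16, 4, 16, 16)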
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from sglang.multimodal_gen.runtime.distributed import (
divide,
get_tp_rank,
get_tp_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.multimodal_gen.runtime.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
method_has_implemented_embedding,
)
from sglang.multimodal_gen.runtime.models.parameter import BasevLLMParameter
from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs
from sglang.multimodal_gen.runtime.platforms import current_platform
DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Create weights for embedding layer."""
weight = Parameter(
torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return F.linear(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight)
def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
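# For example, pad_vocab_size(1010) == 1024 and pad_vocab_size(1024 + 16) == 1088,
# matching the worked example in the VocabParallelEmbedding docstring below.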
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank: int, offset: int = 0
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f + offset, index_l + offset
def vocab_range_from_global_vocab_size(
global_vocab_size: int, rank: int, world_size: int, offset: int = 0
) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, offset=offset
)
@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int
org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int
@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index
@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index
@property
def num_org_elements_padded(self) -> int:
return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
@property
def num_added_elements_padded(self) -> int:
return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements
@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements
@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded
def __post_init__(self):
# sanity checks
assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
input_: torch.Tensor,
org_vocab_start_index: int,
org_vocab_end_index: int,
num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index
)
added_offset = (
added_vocab_start_index
- (org_vocab_end_index - org_vocab_start_index)
- num_org_vocab_padding
)
valid_offset = (org_vocab_start_index * org_vocab_mask) + (
added_offset * added_vocab_mask
)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
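# Worked example (illustrative): for a shard owning original ids [0, 512) with no
# padding and added ids [1000, 1016), out-of-shard ids are remapped to 0 and flagged
# by the returned mask so their embeddings can be zeroed before the all-reduce:
#
#   ids = torch.tensor([3, 600, 1005])
#   local, oov = get_masked_input_and_mask(ids, 0, 512, 0, 1000, 1016)
#   # local == [3, 0, 517]   (1005 -> 512 + (1005 - 1000) == 517)
#   # oov   == [False, True, False]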
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
    index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
prefix: full name of the layer in the state dict
""" # noqa: E501
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
params_dtype: torch.dtype | None = None,
org_num_embeddings: int | None = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
# Keep the input dimensions.
tp_rank = get_tp_rank()
self.tp_size = get_tp_world_size()
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(
self.org_vocab_size, self.padding_size
)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings, self.padding_size
)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(
self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size,
tp_rank,
self.tp_size,
)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
        is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method)
)
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod."
)
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(
self.num_embeddings_padded, self.tp_size
)
assert (
self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index
- self.shard_indices.org_vocab_start_index
)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index
- self.shard_indices.added_vocab_start_index
)
self.quant_method.create_weights(
self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
@classmethod
def _get_indices(
cls,
vocab_size_padded: int,
org_vocab_size_padded: int,
vocab_size: int,
org_vocab_size: int,
tp_rank: int,
tp_size: int,
) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = (
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
)
padded_added_vocab_start_index, padded_added_vocab_end_index = (
vocab_range_from_global_vocab_size(
num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
)
)
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index,
padded_org_vocab_end_index,
padded_added_vocab_start_index,
padded_added_vocab_end_index,
org_vocab_start_index,
org_vocab_end_index,
added_vocab_start_index,
added_vocab_end_index,
)
def get_sharded_to_full_mapping(self) -> list[int] | None:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.
During sampling, we gather logits from all ranks. The relationship
of index->token_id will follow the same format as outlined in the class
docstring. However, after the gather, we want to reindex the final
logits tensor to map index->token_id one-to-one (the index is always
equal the token_id it corresponds to). The indices returned by this
method allow us to do that.
"""
if self.tp_size < 2:
return None
base_embeddings: list[int] = []
added_embeddings: list[int] = []
padding: list[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(
self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size,
tp_rank,
self.tp_size,
)
range_start = self.num_embeddings_per_partition * tp_rank
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
base_embeddings.extend(
range(range_start, range_start + shard_indices.num_org_elements)
)
padding.extend(
range(
range_start + shard_indices.num_org_elements,
range_start + shard_indices.num_org_elements_padded,
)
)
added_embeddings.extend(
range(
range_start + shard_indices.num_org_elements_padded,
range_start
+ shard_indices.num_org_elements_padded
+ shard_indices.num_added_elements,
)
)
padding.extend(
range(
range_start
+ shard_indices.num_org_elements_padded
+ shard_indices.num_added_elements,
range_start
+ shard_indices.num_org_elements_padded
+ shard_indices.num_added_elements_padded,
)
)
assert (
range_start
+ shard_indices.num_org_elements_padded
+ shard_indices.num_added_elements_padded
== range_end
)
ret = base_embeddings + added_embeddings + padding
assert len(ret) == self.num_embeddings_padded
return ret
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
# If the parameter is a gguf weight, then load it directly.
if getattr(param, "is_gguf_weight_type", None):
param.data.copy_(loaded_weight)
param.weight_type = loaded_weight.item()
return
elif isinstance(param, UninitializedParameter):
shape = list(loaded_weight.shape)
if output_dim is not None:
shape[output_dim] = self.num_embeddings_per_partition
param.materialize(tuple(shape), dtype=loaded_weight.dtype)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return
# Shard indexes for loading the weight
start_idx = self.shard_indices.org_vocab_start_index
shard_size = self.shard_indices.org_vocab_end_index - start_idx
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim:
packed_factor = (
param.packed_factor
if isinstance(param, BasevLLMParameter)
else param.pack_factor
)
            assert loaded_weight.shape[output_dim] == (
                self.org_vocab_size // packed_factor
            )
start_idx = start_idx // packed_factor
shard_size = shard_size // packed_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size
# Copy the data. Select chunk corresponding to current shard.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0] :].data.fill_(0)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
input_,
self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index,
)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}"
s += f", org_vocab_size={self.org_vocab_size}"
s += f", num_embeddings_padded={self.num_embeddings_padded}"
s += f", tp_size={self.tp_size}"
return s
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import glob
import json
import os
import time
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable
from copy import deepcopy
from typing import cast
import torch
import torch.distributed as dist
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from torch.distributed import init_device_mesh
from transformers import AutoImageProcessor, AutoProcessor, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from sglang.multimodal_gen.configs.models import EncoderConfig
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.loader.fsdp_load import (
maybe_load_fsdp_model,
shard_model,
)
from sglang.multimodal_gen.runtime.loader.utils import set_default_torch_dtype
from sglang.multimodal_gen.runtime.loader.weight_utils import (
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
pt_weights_iterator,
safetensors_weights_iterator,
)
from sglang.multimodal_gen.runtime.models.registry import ModelRegistry
from sglang.multimodal_gen.runtime.platforms import current_platform
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
get_config,
get_diffusers_config,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import PRECISION_TO_TYPE
logger = init_logger(__name__)
class ComponentLoader(ABC):
"""Base class for loading a specific type of model component."""
def __init__(self, device=None) -> None:
self.device = device
@abstractmethod
def load(self, model_path: str, server_args: ServerArgs, module_name: str):
"""
Load the component based on the model path, architecture, and inference args.
Args:
model_path: Path to the component model
server_args: ServerArgs
Returns:
The loaded component
"""
raise NotImplementedError
@classmethod
def for_module_type(
cls, module_type: str, transformers_or_diffusers: str
) -> "ComponentLoader":
"""
Factory method to create a component loader for a specific module type.
Args:
module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
transformers_or_diffusers: Whether the module is from transformers or diffusers
Returns:
A component loader for the specified module type
"""
# Map of module types to their loader classes and expected library
module_loaders = {
"scheduler": (SchedulerLoader, "diffusers"),
"transformer": (TransformerLoader, "diffusers"),
"transformer_2": (TransformerLoader, "diffusers"),
"vae": (VAELoader, "diffusers"),
"text_encoder": (TextEncoderLoader, "transformers"),
"text_encoder_2": (TextEncoderLoader, "transformers"),
"tokenizer": (TokenizerLoader, "transformers"),
"tokenizer_2": (TokenizerLoader, "transformers"),
"image_processor": (ImageProcessorLoader, "transformers"),
"image_encoder": (ImageEncoderLoader, "transformers"),
"processor": (AutoProcessorLoader, "transformers"),
}
if module_type in module_loaders:
loader_cls, expected_library = module_loaders[module_type]
# Assert that the library matches what's expected for this module type
assert (
transformers_or_diffusers == expected_library
), f"{module_type} must be loaded from {expected_library}, got {transformers_or_diffusers}"
return loader_cls()
# For unknown module types, use a generic loader
logger.warning(
"No specific loader found for module type: %s. Using generic loader.",
module_type,
)
return GenericComponentLoader(transformers_or_diffusers)
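# Illustrative dispatch examples (not part of the original file):
#   ComponentLoader.for_module_type("vae", "diffusers")           -> VAELoader()
#   ComponentLoader.for_module_type("tokenizer", "transformers")  -> TokenizerLoader()
# Unknown module types fall back to GenericComponentLoader.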
class TextEncoderLoader(ComponentLoader):
"""Loader for text encoders."""
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
allow_patterns_overrides: list[str] | None = None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def _prepare_weights(
self,
model_name_or_path: str,
fall_back_to_pt: bool,
allow_patterns_overrides: list[str] | None,
) -> tuple[str, list[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
# model_name_or_path = (self._maybe_download_from_modelscope(
# model_name_or_path, revision) or model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
assert is_local, "Model path must be a local directory"
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
allow_patterns = ["*.safetensors", "*.bin"]
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if allow_patterns_overrides is not None:
allow_patterns = allow_patterns_overrides
hf_folder = model_name_or_path
hf_weights_files: list[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if use_safetensors:
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, source: "Source", to_cpu: bool
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path,
source.fall_back_to_pt,
source.allow_patterns_overrides,
)
if use_safetensors:
weights_iterator = safetensors_weights_iterator(
hf_weights_files, to_cpu=to_cpu
)
else:
weights_iterator = pt_weights_iterator(hf_weights_files, to_cpu=to_cpu)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
def _get_all_weights(
self,
model: nn.Module,
model_path: str,
to_cpu: bool,
) -> Generator[tuple[str, torch.Tensor], None, None]:
primary_weights = TextEncoderLoader.Source(
model_path,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
)
yield from self._get_weights_iterator(primary_weights, to_cpu)
secondary_weights = cast(
Iterable[TextEncoderLoader.Source],
getattr(model, "secondary_weights", ()),
)
for source in secondary_weights:
yield from self._get_weights_iterator(source, to_cpu)
def load(self, model_path: str, server_args: ServerArgs, module_name: str):
"""Load the text encoders based on the model path, and inference args."""
# model_config: PretrainedConfig = get_hf_config(
# model=model_path,
# trust_remote_code=server_args.trust_remote_code,
# revision=server_args.revision,
# model_override_args=None,
# )
diffusers_pretrained_config = get_config(model_path, trust_remote_code=True)
model_config = get_diffusers_config(model=model_path)
model_config.pop("_name_or_path", None)
model_config.pop("transformers_version", None)
model_config.pop("model_type", None)
model_config.pop("tokenizer_class", None)
model_config.pop("torch_dtype", None)
logger.info("HF model config: %s", model_config)
def is_not_first_encoder(module_name):
return "2" in module_name
# TODO(mick): had to throw an exception for different text-encoder arch
if not is_not_first_encoder(module_name):
encoder_config = server_args.pipeline_config.text_encoder_configs[0]
encoder_config.update_model_arch(model_config)
for key, value in diffusers_pretrained_config.__dict__.items():
setattr(encoder_config.arch_config, key, value)
encoder_dtype = server_args.pipeline_config.text_encoder_precisions[0]
else:
assert len(server_args.pipeline_config.text_encoder_configs) == 2
encoder_config = server_args.pipeline_config.text_encoder_configs[1]
encoder_config.update_model_arch(model_config)
encoder_dtype = server_args.pipeline_config.text_encoder_precisions[1]
target_device = get_local_torch_device()
# TODO(will): add support for other dtypes
return self.load_model(
model_path,
encoder_config,
target_device,
server_args,
encoder_dtype,
)
def load_model(
self,
model_path: str,
model_config: EncoderConfig,
target_device: torch.device,
server_args: ServerArgs,
dtype: str = "fp16",
):
use_cpu_offload = (
server_args.text_encoder_cpu_offload
and len(getattr(model_config, "_fsdp_shard_conditions", [])) > 0
)
if server_args.text_encoder_cpu_offload:
target_device = (
torch.device("mps")
if current_platform.is_mps()
else torch.device("cpu")
)
with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]):
with target_device:
architectures = getattr(model_config, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
model = model_cls(model_config)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model, model_path, to_cpu=use_cpu_offload)
)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights
- self.counter_before_loading_weights,
)
# Explicitly move model to target device after loading weights
model = model.to(target_device)
if use_cpu_offload:
# Disable FSDP for MPS as it's not compatible
if current_platform.is_mps():
logger.info(
"Disabling FSDP sharding for MPS platform as it's not compatible"
)
else:
mesh = init_device_mesh(
"cuda",
mesh_shape=(1, dist.get_world_size()),
mesh_dim_names=("offload", "replicate"),
)
shard_model(
model,
cpu_offload=True,
reshard_after_forward=True,
mesh=mesh["offload"],
fsdp_shard_conditions=model._fsdp_shard_conditions,
pin_cpu_memory=server_args.pin_cpu_memory,
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
# if loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)
return model.eval()
class ImageEncoderLoader(TextEncoderLoader):
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load the text encoders based on the model path, and inference args."""
# model_config: PretrainedConfig = get_hf_config(
# model=model_path,
# trust_remote_code=server_args.trust_remote_code,
# revision=server_args.revision,
# model_override_args=None,
# )
with open(os.path.join(model_path, "config.json")) as f:
model_config = json.load(f)
model_config.pop("_name_or_path", None)
model_config.pop("transformers_version", None)
model_config.pop("torch_dtype", None)
model_config.pop("model_type", None)
logger.info("HF model config: %s", model_config)
encoder_config = server_args.pipeline_config.image_encoder_config
encoder_config.update_model_arch(model_config)
if server_args.image_encoder_cpu_offload:
target_device = (
torch.device("mps")
if current_platform.is_mps()
else torch.device("cpu")
)
else:
target_device = get_local_torch_device()
# TODO(will): add support for other dtypes
return self.load_model(
model_path,
encoder_config,
target_device,
server_args,
server_args.pipeline_config.image_encoder_precision,
)
class ImageProcessorLoader(ComponentLoader):
"""Loader for image processor."""
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load the image processor based on the model path, and inference args."""
logger.info("Loading image processor from %s", model_path)
image_processor = AutoImageProcessor.from_pretrained(model_path, use_fast=True)
logger.info("Loaded image processor: %s", image_processor.__class__.__name__)
return image_processor
class AutoProcessorLoader(ComponentLoader):
"""Loader for auto processor."""
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load the image processor based on the model path, and inference args."""
logger.info("Loading auto processor from %s", model_path)
processor = AutoProcessor.from_pretrained(
model_path,
)
logger.info("Loaded auto processor: %s", processor.__class__.__name__)
return processor
class TokenizerLoader(ComponentLoader):
"""Loader for tokenizers."""
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load the tokenizer based on the model path, and inference args."""
logger.info("Loading tokenizer from %s", model_path)
tokenizer = AutoTokenizer.from_pretrained(
model_path, # "<path to model>/tokenizer"
# in v0, this was same string as encoder_name "ClipTextModel"
# TODO(will): pass these tokenizer kwargs from inference args? Maybe
# other method of config?
padding_size="right",
)
logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__)
return tokenizer
class VAELoader(ComponentLoader):
"""Loader for VAE."""
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load the VAE based on the model path, and inference args."""
config = get_diffusers_config(model=model_path)
class_name = config.pop("_class_name")
assert (
class_name is not None
), "Model config does not contain a _class_name attribute. Only diffusers format is supported."
server_args.model_paths["vae"] = model_path
# TODO: abstract these logics
logger.info("HF model config: %s", config)
vae_config = server_args.pipeline_config.vae_config
vae_config.update_model_arch(config)
# NOTE: some post init logics are only available after updated with config
vae_config.post_init()
if server_args.vae_cpu_offload:
target_device = (
torch.device("mps")
if current_platform.is_mps()
else torch.device("cpu")
)
else:
target_device = get_local_torch_device()
with set_default_torch_dtype(
PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
):
vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
vae = vae_cls(vae_config).to(target_device)
# Find all safetensors files
safetensors_list = glob.glob(os.path.join(str(model_path), "*.safetensors"))
# TODO(PY)
assert (
len(safetensors_list) == 1
), f"Found {len(safetensors_list)} safetensors files in {model_path}"
loaded = safetensors_load_file(safetensors_list[0])
vae.load_state_dict(
loaded, strict=False
) # We might only load encoder or decoder
return vae.eval()
class TransformerLoader(ComponentLoader):
"""Loader for transformer."""
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load the transformer based on the model path, and inference args."""
config = get_diffusers_config(model=model_path)
hf_config = deepcopy(config)
cls_name = config.pop("_class_name")
if cls_name is None:
raise ValueError(
"Model config does not contain a _class_name attribute. "
"Only diffusers format is supported."
)
logger.info("transformer cls_name: %s", cls_name)
if server_args.override_transformer_cls_name is not None:
cls_name = server_args.override_transformer_cls_name
logger.info("Overriding transformer cls_name to %s", cls_name)
server_args.model_paths["transformer"] = model_path
# Config from Diffusers supersedes sgl_diffusion's model config
dit_config = server_args.pipeline_config.dit_config
dit_config.update_model_arch(config)
model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
# Find all safetensors files
safetensors_list = glob.glob(os.path.join(str(model_path), "*.safetensors"))
if not safetensors_list:
raise ValueError(f"No safetensors files found in {model_path}")
# Check if we should use custom initialization weights
custom_weights_path = getattr(
server_args, "init_weights_from_safetensors", None
)
        use_custom_weights = custom_weights_path is not None
if use_custom_weights:
logger.info(
"Using custom initialization weights from: %s", custom_weights_path
)
assert (
custom_weights_path is not None
), "Custom initialization weights must be provided"
if os.path.isdir(custom_weights_path):
safetensors_list = glob.glob(
os.path.join(str(custom_weights_path), "*.safetensors")
)
else:
assert custom_weights_path.endswith(
".safetensors"
), "Custom initialization weights must be a safetensors file"
safetensors_list = [custom_weights_path]
logger.info(
"Loading model from %s safetensors files: %s",
len(safetensors_list),
safetensors_list,
)
default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision]
# Load the model using FSDP loader
logger.info("Loading %s, default_dtype: %s", cls_name, default_dtype)
assert server_args.hsdp_shard_dim is not None
model = maybe_load_fsdp_model(
model_cls=model_cls,
init_params={"config": dit_config, "hf_config": hf_config},
weight_dir_list=safetensors_list,
device=get_local_torch_device(),
hsdp_replicate_dim=server_args.hsdp_replicate_dim,
hsdp_shard_dim=server_args.hsdp_shard_dim,
cpu_offload=server_args.dit_cpu_offload,
pin_cpu_memory=server_args.pin_cpu_memory,
fsdp_inference=server_args.use_fsdp_inference,
# TODO(will): make these configurable
default_dtype=default_dtype,
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
output_dtype=None,
)
total_params = sum(p.numel() for p in model.parameters())
logger.info("Loaded model with %.2fB parameters", total_params / 1e9)
assert (
next(model.parameters()).dtype == default_dtype
), "Model dtype does not match default dtype"
model = model.eval()
return model
class SchedulerLoader(ComponentLoader):
"""Loader for scheduler."""
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load the scheduler based on the model path, and inference args."""
config = get_diffusers_config(model=model_path)
class_name = config.pop("_class_name")
assert (
class_name is not None
), "Model config does not contain a _class_name attribute. Only diffusers format is supported."
scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name)
scheduler = scheduler_cls(**config)
if server_args.pipeline_config.flow_shift is not None:
scheduler.set_shift(server_args.pipeline_config.flow_shift)
if server_args.pipeline_config.timesteps_scale is not None:
scheduler.set_timesteps_scale(server_args.pipeline_config.timesteps_scale)
return scheduler
class GenericComponentLoader(ComponentLoader):
"""Generic loader for components that don't have a specific loader."""
def __init__(self, library="transformers") -> None:
super().__init__()
self.library = library
def load(self, model_path: str, server_args: ServerArgs, *args):
"""Load a generic component based on the model path, and inference args."""
logger.warning(
"Using generic loader for %s with library %s", model_path, self.library
)
if self.library == "transformers":
from transformers import AutoModel
model = AutoModel.from_pretrained(
model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
logger.info(
"Loaded generic transformers model: %s", model.__class__.__name__
)
return model
elif self.library == "diffusers":
logger.warning(
"Generic loading for diffusers components is not fully implemented"
)
model_config = get_diffusers_config(model=model_path)
logger.info("Diffusers Model config: %s", model_config)
# This is a placeholder - in a real implementation, you'd need to handle this properly
return None
else:
raise ValueError(f"Unsupported library: {self.library}")
class PipelineComponentLoader:
"""
Utility class for loading pipeline components.
This replaces the chain of if-else statements in load_pipeline_module.
"""
@staticmethod
def load_module(
module_name: str,
component_model_path: str,
transformers_or_diffusers: str,
server_args: ServerArgs,
):
"""
Load a pipeline module.
Args:
module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
component_model_path: Path to the component model
transformers_or_diffusers: Whether the module is from transformers or diffusers
Returns:
The loaded module
"""
logger.info(
"Loading %s using %s from %s",
module_name,
transformers_or_diffusers,
component_model_path,
)
# Get the appropriate loader for this module type
loader = ComponentLoader.for_module_type(module_name, transformers_or_diffusers)
try:
# Load the module
return loader.load(component_model_path, server_args, module_name)
except Exception as e:
logger.error(
f"Error while loading component: {module_name}, {component_model_path=}"
)
raise e
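# Hypothetical call sketch (the path and server_args below are placeholders, not from
# the original file):
#
#   vae = PipelineComponentLoader.load_module(
#       module_name="vae",
#       component_model_path="/path/to/model/vae",
#       transformers_or_diffusers="diffusers",
#       server_args=server_args,
#   )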
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from torchtune
# Copyright 2024 The TorchTune Authors.
# Copyright 2025 The sgl-diffusion Authors.
import contextlib
from collections.abc import Callable, Generator
from itertools import chain
from typing import Any
import torch
from torch import nn
from torch.distributed import DeviceMesh, init_device_mesh
from torch.distributed._tensor import distribute_tensor
from torch.distributed.fsdp import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
fully_shard,
)
from torch.nn.modules.module import _IncompatibleKeys
from sglang.multimodal_gen.runtime.loader.utils import (
get_param_names_mapping,
hf_to_custom_state_dict,
)
from sglang.multimodal_gen.runtime.loader.weight_utils import (
safetensors_weights_iterator,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import set_mixed_precision_policy
logger = init_logger(__name__)
# TODO(PY): move this to utils elsewhere
@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
"""
Context manager to set torch's default dtype.
Args:
dtype (torch.dtype): The desired default dtype inside the context manager.
Returns:
ContextManager: context manager for setting default dtype.
Example:
>>> with set_default_dtype(torch.bfloat16):
>>> x = torch.tensor([1, 2, 3])
>>> x.dtype
torch.bfloat16
"""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(old_dtype)
# TODO(PY): add compile option
def maybe_load_fsdp_model(
model_cls: type[nn.Module],
init_params: dict[str, Any],
weight_dir_list: list[str],
device: torch.device,
hsdp_replicate_dim: int,
hsdp_shard_dim: int,
default_dtype: torch.dtype,
param_dtype: torch.dtype,
reduce_dtype: torch.dtype,
cpu_offload: bool = False,
fsdp_inference: bool = False,
output_dtype: torch.dtype | None = None,
pin_cpu_memory: bool = True,
) -> torch.nn.Module:
"""
    Load the model, sharding it with FSDP when FSDP inference is enabled; otherwise load it unsharded.
"""
# NOTE(will): cast_forward_inputs=True shouldn't be needed as we are
# manually casting the inputs to the model
mp_policy = MixedPrecisionPolicy(
param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=False
)
set_mixed_precision_policy(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
output_dtype=output_dtype,
mp_policy=mp_policy,
)
with set_default_dtype(default_dtype), torch.device("meta"):
model = model_cls(**init_params)
# Check if we should use FSDP
use_fsdp = fsdp_inference
# Disable FSDP for MPS as it's not compatible
from sglang.multimodal_gen.runtime.platforms import current_platform
if current_platform.is_mps():
use_fsdp = False
logger.info("Disabling FSDP for MPS platform as it's not compatible")
if use_fsdp:
world_size = hsdp_replicate_dim * hsdp_shard_dim
if not fsdp_inference:
hsdp_replicate_dim = world_size
hsdp_shard_dim = 1
device_mesh = init_device_mesh(
"cuda",
# (Replicate(), Shard(dim=0))
mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim),
mesh_dim_names=("replicate", "shard"),
)
shard_model(
model,
cpu_offload=cpu_offload,
reshard_after_forward=True,
mp_policy=mp_policy,
mesh=device_mesh,
fsdp_shard_conditions=model._fsdp_shard_conditions,
pin_cpu_memory=pin_cpu_memory,
)
weight_iterator = safetensors_weights_iterator(weight_dir_list)
param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping)
load_model_from_full_model_state_dict(
model,
weight_iterator,
device,
default_dtype,
strict=True,
cpu_offload=cpu_offload,
param_names_mapping=param_names_mapping_fn,
)
for n, p in chain(model.named_parameters(), model.named_buffers()):
if p.is_meta:
raise RuntimeError(f"Unexpected param or buffer {n} on meta device.")
# Avoid unintended computation graph accumulation during inference
if isinstance(p, torch.nn.Parameter):
p.requires_grad = False
return model
def shard_model(
model,
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(), # noqa
mesh: DeviceMesh | None = None,
fsdp_shard_conditions: list[Callable[[str, nn.Module], bool]] = [], # noqa
pin_cpu_memory: bool = True,
) -> None:
"""
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
    This method iterates over the model's named modules from the bottom up and shards
    those that meet any of the criteria in ``fsdp_shard_conditions``.
Args:
        model (nn.Module): Model to shard with FSDP.
cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer
states to CPU.
reshard_after_forward (bool): Whether to reshard parameters and buffers after
the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism.
Default to None.
fsdp_shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine
which modules to shard with FSDP.
pin_cpu_memory (bool): If set to True, FSDP will pin the CPU memory of the offloaded parameters.
Raises:
ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered.
"""
if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0:
logger.warning(
"The FSDP shard condition list is empty or None. No modules will be sharded in %s",
type(model).__name__,
)
return
fsdp_kwargs = {
"reshard_after_forward": reshard_after_forward,
"mesh": mesh,
"mp_policy": mp_policy,
}
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy(pin_memory=pin_cpu_memory)
# iterating in reverse to start with
# lowest-level modules first
num_layers_sharded = 0
    # TODO(will): don't reshard after forward for the last layer to save on the
    # all-gather that will immediately happen.
    # Shard the model with FSDP.
for n, m in reversed(list(model.named_modules())):
if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]):
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1
if num_layers_sharded == 0:
raise ValueError(
"No layer modules were sharded. Please check if shard conditions are working as expected."
)
# Finally shard the entire model to account for any stragglers
fully_shard(model, **fsdp_kwargs)
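# Example (illustrative sketch): a shard condition is a predicate over
# (module_name, module). A hypothetical DiT could shard every transformer
# block like this:
#
#   >>> def _is_dit_block(name: str, module: nn.Module) -> bool:
#   ...     return name.startswith("blocks.") and name.count(".") == 1
#   >>> shard_model(
#   ...     model,
#   ...     cpu_offload=False,
#   ...     mesh=device_mesh,
#   ...     fsdp_shard_conditions=[_is_dit_block],
#   ... )
#
# Each matching module is wrapped with fully_shard; the whole model is wrapped
# last so any remaining parameters are still sharded.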
# TODO(PY): device mesh for cfg parallel
def load_model_from_full_model_state_dict(
model: FSDPModule | torch.nn.Module,
full_sd_iterator: Generator[tuple[str, torch.Tensor], None, None],
device: torch.device,
param_dtype: torch.dtype,
strict: bool = False,
cpu_offload: bool = False,
param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None,
) -> _IncompatibleKeys:
"""
    Convert a full state dict into a sharded state dict
    and load it into an FSDP-wrapped model (if training) or a plain Hugging Face model.
Args:
model (Union[FSDPModule, torch.nn.Module]): Model to generate fully qualified names for cpu_state_dict
full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs
device (torch.device): device used to move full state dict tensors
param_dtype (torch.dtype): dtype used to move full state dict tensors
        strict (bool): whether to load the model in strict mode
        cpu_offload (bool): whether FSDP CPU offload is enabled
        param_names_mapping (Optional[Callable[[str], tuple[str, Any, Any]]]): a function that maps a full param name to the custom param name (plus merge info)
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Raises:
        NotImplementedError: If the FSDP mesh has more than one dimension.
"""
meta_sd = model.state_dict()
sharded_sd = {}
custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict(
full_sd_iterator, param_names_mapping
) # type: ignore
for target_param_name, full_tensor in custom_param_sd.items():
meta_sharded_param = meta_sd.get(target_param_name)
if meta_sharded_param is None:
raise ValueError(
f"Parameter {target_param_name} not found in custom model state dict. The hf to custom mapping may be incorrect."
)
if not hasattr(meta_sharded_param, "device_mesh"):
full_tensor = full_tensor.to(device=device, dtype=param_dtype)
# In cases where parts of the model aren't sharded, some parameters will be plain tensors
sharded_tensor = full_tensor
else:
full_tensor = full_tensor.to(device=device, dtype=param_dtype)
sharded_tensor = distribute_tensor(
full_tensor,
meta_sharded_param.device_mesh,
meta_sharded_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[target_param_name] = nn.Parameter(sharded_tensor)
model.reverse_param_names_mapping = reverse_param_names_mapping
unused_keys = set(meta_sd.keys()) - set(sharded_sd.keys())
if unused_keys:
logger.warning("Found unloaded parameters in meta state dict: %s", unused_keys)
# List of allowed parameter name patterns
ALLOWED_NEW_PARAM_PATTERNS = ["gate_compress"] # Can be extended as needed
for new_param_name in unused_keys:
if not any(pattern in new_param_name for pattern in ALLOWED_NEW_PARAM_PATTERNS):
logger.error(
"Unsupported new parameter: %s. Allowed patterns: %s",
new_param_name,
ALLOWED_NEW_PARAM_PATTERNS,
)
raise ValueError(
f"New parameter '{new_param_name}' is not supported. "
f"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed."
)
meta_sharded_param = meta_sd.get(new_param_name)
if not hasattr(meta_sharded_param, "device_mesh"):
# Initialize with zeros
sharded_tensor = torch.zeros_like(
meta_sharded_param, device=device, dtype=param_dtype
)
else:
# Initialize with zeros and distribute
full_tensor = torch.zeros_like(
meta_sharded_param, device=device, dtype=param_dtype
)
sharded_tensor = distribute_tensor(
full_tensor,
meta_sharded_param.device_mesh,
meta_sharded_param.placements,
)
if cpu_offload:
sharded_tensor = sharded_tensor.cpu()
sharded_sd[new_param_name] = nn.Parameter(sharded_tensor)
# choose `assign=True` since we cannot call `copy_` on meta tensor
return model.load_state_dict(sharded_sd, strict=strict, assign=True)
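# Example (illustrative): when a meta parameter carries a device_mesh,
# distribute_tensor() splits the full tensor according to its placements.
# For instance, a [4096, 4096] weight with placements (Replicate(), Shard(0))
# on a (2, 4) mesh ends up as a [1024, 4096] local shard per rank:
#
#   >>> sharded = distribute_tensor(full_tensor, meta_param.device_mesh, meta_param.placements)
#   >>> sharded.to_local().shape
#   torch.Size([1024, 4096])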
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""Utilities for selecting and loading models."""
import contextlib
import re
from collections import defaultdict
from collections.abc import Callable, Iterator
from typing import Any
import torch
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
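# Example usage (illustrative): materialize a layer in a non-default dtype
# without leaking the change to the caller.
#
#   >>> with set_default_torch_dtype(torch.bfloat16):
#   ...     layer = torch.nn.Linear(8, 8)
#   >>> layer.weight.dtype
#   torch.bfloat16
#   >>> torch.get_default_dtype()  # restored (assuming float32 before)
#   torch.float32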
def get_param_names_mapping(
mapping_dict: dict[str, str]
) -> Callable[[str], tuple[str, Any, Any]]:
"""
Creates a mapping function that transforms parameter names using regex patterns.
Args:
        mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement
            patterns; a replacement may also be a (pattern, merge_index, num_params_to_merge) tuple
    Returns:
        Callable[[str], tuple[str, Any, Any]]: A function that maps a parameter name to
            (target_name, merge_index, num_params_to_merge)
"""
def mapping_fn(name: str) -> tuple[str, Any, Any]:
# Try to match and transform the name using the regex patterns in mapping_dict
for pattern, replacement in mapping_dict.items():
match = re.match(pattern, name)
if match:
merge_index = None
total_splitted_params = None
if isinstance(replacement, tuple):
merge_index = replacement[1]
total_splitted_params = replacement[2]
replacement = replacement[0]
name = re.sub(pattern, replacement, name)
return name, merge_index, total_splitted_params
# If no pattern matches, return the original name
return name, None, None
return mapping_fn
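# Example (illustrative; the parameter names below are hypothetical): values in
# mapping_dict may be plain replacement strings or
# (replacement, merge_index, num_params_to_merge) tuples for fused weights.
#
#   >>> mapping_fn = get_param_names_mapping({
#   ...     r"^blocks\.(\d+)\.attn\.q_proj\.(.*)$": (r"blocks.\1.attn.qkv.\2", 0, 3),
#   ...     r"^blocks\.(\d+)\.attn\.k_proj\.(.*)$": (r"blocks.\1.attn.qkv.\2", 1, 3),
#   ...     r"^blocks\.(\d+)\.attn\.v_proj\.(.*)$": (r"blocks.\1.attn.qkv.\2", 2, 3),
#   ... })
#   >>> mapping_fn("blocks.0.attn.k_proj.weight")
#   ('blocks.0.attn.qkv.weight', 1, 3)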
def hf_to_custom_state_dict(
hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]],
param_names_mapping: Callable[[str], tuple[str, Any, Any]],
) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]:
"""
Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary.
Args:
hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary
param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format
Returns:
custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict
reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf
"""
custom_param_sd = {}
to_merge_params = defaultdict(dict) # type: ignore
reverse_param_names_mapping = {}
if isinstance(hf_param_sd, dict):
hf_param_sd = hf_param_sd.items() # type: ignore
for source_param_name, full_tensor in hf_param_sd: # type: ignore
target_param_name, merge_index, num_params_to_merge = param_names_mapping(
source_param_name
)
reverse_param_names_mapping[target_param_name] = (
source_param_name,
merge_index,
num_params_to_merge,
)
if merge_index is not None:
to_merge_params[target_param_name][merge_index] = full_tensor
if len(to_merge_params[target_param_name]) == num_params_to_merge:
# cat at output dim according to the merge_index order
sorted_tensors = [
to_merge_params[target_param_name][i]
for i in range(num_params_to_merge)
]
full_tensor = torch.cat(sorted_tensors, dim=0)
del to_merge_params[target_param_name]
else:
continue
custom_param_sd[target_param_name] = full_tensor
return custom_param_sd, reverse_param_names_mapping
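# Example (illustrative, hypothetical shapes): with a q/k/v -> qkv mapping like
# the one sketched above, three [64, 64] projection weights sharing a target
# name are concatenated along dim 0 in merge_index order:
#
#   >>> hf_sd = {
#   ...     "blocks.0.attn.q_proj.weight": torch.randn(64, 64),
#   ...     "blocks.0.attn.k_proj.weight": torch.randn(64, 64),
#   ...     "blocks.0.attn.v_proj.weight": torch.randn(64, 64),
#   ... }
#   >>> custom_sd, reverse = hf_to_custom_state_dict(hf_sd, mapping_fn)
#   >>> custom_sd["blocks.0.attn.qkv.weight"].shape
#   torch.Size([192, 64])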
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import hashlib
import json
import os
import tempfile
from collections.abc import Generator
from pathlib import Path
import filelock
import huggingface_hub.constants
import torch
from safetensors.torch import safe_open
from tqdm.auto import tqdm
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = tempfile.gettempdir()
def enable_hf_transfer() -> None:
"""automatically activates hf_transfer"""
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
def get_lock(model_name_or_path: str | Path, cache_dir: str | None = None):
lock_dir = cache_dir or temp_dir
model_name_or_path = str(model_name_or_path)
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
model_name = model_name_or_path.replace("/", "-")
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
# add hash to avoid conflict with old users' lock files
lock_file_name = hash_name + model_name + ".lock"
# mode 0o666 is required for the filelock to be shared across users
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
return lock
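# Example usage (illustrative; the model name is hypothetical): serialize a
# download across processes that share the same cache directory.
#
#   >>> with get_lock("some-org/some-model", cache_dir=None):
#   ...     pass  # download or convert weights here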
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(
hf_weights_files: list[str], hf_folder: str, index_file: str
) -> list[str]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name = os.path.join(hf_folder, index_file)
if not os.path.isfile(index_file_name):
return hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with open(index_file_name) as f:
weight_map = json.load(f)["weight_map"]
weight_files_in_index = set()
for weight_name in weight_map:
weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
# Filter out any fields that are not found in the index file.
hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
return hf_weights_files
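# Example (illustrative): the index file maps each tensor name to the shard
# that stores it, e.g.
#
#   {"weight_map": {"blocks.0.attn.qkv.weight": "model-00001-of-00002.safetensors",
#                   "blocks.1.mlp.fc_in.weight": "model-00002-of-00002.safetensors"}}
#
# so a duplicate consolidated.safetensors in the same folder is dropped because
# it never appears in the weight_map.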
def filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]:
"""
Exclude files that are not needed for inference.
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
"""
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
]
return hf_weights_files
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
def safetensors_weights_iterator(
hf_weights_files: list[str],
to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
device = "cpu" if to_cpu else str(get_local_torch_device())
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt", device=device) as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
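# Example usage (illustrative; the path is hypothetical): stream tensors
# shard by shard instead of loading the whole checkpoint at once.
#
#   >>> files = ["/path/to/model-00001-of-00002.safetensors"]
#   >>> for name, tensor in safetensors_weights_iterator(files):
#   ...     print(name, tuple(tensor.shape))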
def pt_weights_iterator(
hf_weights_files: list[str],
to_cpu: bool = True,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
device = "cpu" if to_cpu else str(get_local_torch_device())
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location=device, weights_only=True)
yield from state.items()
del state
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
try:
if param.numel() == 1 and loaded_weight.numel() == 1:
# Sometimes scalar values aren't considered tensors with shapes
# so if both param and loaded_weight are a scalar,
# "broadcast" instead of copy
param.data.fill_(loaded_weight.item())
else:
assert param.size() == loaded_weight.size(), (
f"Attempted to load weight ({loaded_weight.size()}) "
f"into parameter ({param.size()})"
)
param.data.copy_(loaded_weight)
except Exception:
# NOTE: This exception is added for the purpose of setting breakpoint to
# debug weight loading issues.
raise
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
"""Remap the name of FP8 k/v_scale parameters.
This function handles the remapping of FP8 k/v_scale parameter names.
It detects if the given name ends with a suffix and attempts to remap
it to the expected name format in the model. If the remapped name is not
found in the params_dict, a warning is printed and None is returned.
Args:
name (str): The original loaded checkpoint parameter name.
params_dict (dict): Dictionary containing the model's named parameters.
Returns:
str: The remapped parameter name if successful, or the original name
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
if name.endswith(".kv_scale"):
logger.warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
"Functionally, we will remap kv_scale to k_scale and duplicate "
"k_scale to v_scale"
)
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
if remapped_name not in params_dict:
logger.warning_once(
f"Found kv_scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). kv_scale is "
"not loaded."
)
return None
return remapped_name
possible_scale_names = [".k_scale", ".v_scale"]
modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
for scale_name in possible_scale_names:
if name.endswith(scale_name):
if any(mo_scale_name in name for mo_scale_name in modelopt_scale_names):
remapped_name = name.replace(
f".self_attn.{scale_name[1]}_proj{scale_name}",
f".self_attn.attn{scale_name}",
)
else:
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
logger.warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_name}). {scale_name} is "
"not loaded."
)
return None
return remapped_name
# If there were no matches, return the untouched param name
return name
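# Example (illustrative, hypothetical parameter names): a ModelOpt-style
# checkpoint name is remapped onto the fused attention module if the target
# exists in params_dict.
#
#   >>> params = {"model.layers.0.self_attn.attn.k_scale": torch.tensor(1.0)}
#   >>> maybe_remap_kv_scale_name("model.layers.0.self_attn.k_proj.k_scale", params)
#   'model.layers.0.self_attn.attn.k_scale'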
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py
import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Type
import torch
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
if TYPE_CHECKING:
from sglang.multimodal_gen.runtime.layers.attention import AttentionMetadata
from sglang.multimodal_gen.runtime.pipelines import Req
logger = init_logger(__name__)
# TODO(will): check if this is needed
# track_batchsize: bool = envs.SGL_DIFFUSION_LOG_BATCHSIZE_INTERVAL >= 0
track_batchsize: bool = False
last_logging_time: float = 0
forward_start_time: float = 0
# batchsize_logging_interval: float = envs.SGL_DIFFUSION_LOG_BATCHSIZE_INTERVAL
batchsize_logging_interval: float = 1000
batchsize_forward_time: defaultdict = defaultdict(list)
@dataclass
class ForwardContext:
current_timestep: int
# TODO(will): check this arg
# copy from vllm_config.compilation_config.static_forward_context
# attn_layers: Dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
forward_batch: Optional["Req"] = None
attention_backend_cls: Optional[Type] = None
def set_attn_backend_cls(self, attention_backend_cls: Type):
if self.attention_backend_cls:
if self.attention_backend_cls != attention_backend_cls:
raise RuntimeError(
f"Different types of attention backend in a same context detected, previous: {self.attention_backend_cls}, new: {attention_backend_cls}"
)
else:
self.attention_backend_cls = attention_backend_cls
_forward_context: Optional["ForwardContext"] = None
def get_forward_context() -> "ForwardContext":
"""Get the current forward context."""
assert _forward_context is not None, (
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context."
)
return _forward_context
# TODO(will): finalize the interface
@contextmanager
def set_forward_context(
current_timestep, attn_metadata, forward_batch: Optional["Req"] = None
):
"""A context manager that stores the current forward context,
    which can include attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global forward_start_time
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
current_timestep=current_timestep,
attn_metadata=attn_metadata,
forward_batch=forward_batch,
)
try:
yield
finally:
global last_logging_time, batchsize_logging_interval
if need_to_track_batchsize:
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = (
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
)
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
now = time.perf_counter()
# time measurement is in milliseconds
batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)
if now - last_logging_time > batchsize_logging_interval:
last_logging_time = now
forward_stats = []
for bs, times in batchsize_forward_time.items():
if len(times) <= 1:
# can be cudagraph / profiling run
continue
medium = torch.quantile(torch.tensor(times), q=0.5).item()
medium = round(medium, 2)
forward_stats.append((bs, len(times), medium))
forward_stats.sort(key=lambda x: x[1], reverse=True)
if forward_stats:
logger.info(
(
"Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"
),
forward_stats,
)
_forward_context = prev_context
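# Example usage (illustrative sketch; `transformer`, `latents`, etc. are
# hypothetical): wrap each denoising step so that layers can read the per-step
# metadata via get_forward_context().
#
#   with set_forward_context(current_timestep=t, attn_metadata=attn_metadata,
#                            forward_batch=req):
#       noise_pred = transformer(latents, encoder_hidden_states, timestep=t)
#   # inside an attention layer:
#   #   ctx = get_forward_context()
#   #   ctx.current_timestep, ctx.attn_metadata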
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import multiprocessing as mp
import os
from typing import List
import torch
from setproctitle import setproctitle
from sglang.multimodal_gen.runtime.distributed import (
get_sp_group,
maybe_init_distributed_environment_and_model_parallel,
)
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_cfg_group,
get_tp_group,
)
from sglang.multimodal_gen.runtime.pipelines import build_pipeline
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req
from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs
from sglang.multimodal_gen.runtime.utils.common import set_cuda_arch
from sglang.multimodal_gen.runtime.utils.logging_utils import (
configure_logger,
init_logger,
suppress_other_loggers,
)
logger = init_logger(__name__)
# ANSI color codes
CYAN = "\033[1;36m"
RESET = "\033[0;0m"
class GPUWorker:
"""
A worker that executes the model on a single GPU.
"""
def __init__(
self,
local_rank: int,
rank: int,
master_port: int,
server_args: ServerArgs,
):
self.local_rank = local_rank
self.rank = rank
self.master_port = master_port
        # FIXME: should we use tcp as the distributed init method?
self.server_args = server_args
self.pipeline = None
self.init_device_and_model()
self.sp_group = get_sp_group()
self.sp_cpu_group = self.sp_group.cpu_group
self.tp_group = get_tp_group()
self.tp_cpu_group = self.tp_group.cpu_group
self.cfg_group = get_cfg_group()
self.cfg_cpu_group = self.cfg_group.cpu_group
def init_device_and_model(self) -> None:
"""Initialize the device and load the model."""
setproctitle(f"sgl_diffusion::scheduler:{self.local_rank}")
torch.cuda.set_device(self.local_rank)
# Set environment variables for distributed initialization
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(self.master_port)
os.environ["LOCAL_RANK"] = str(self.local_rank)
os.environ["RANK"] = str(self.rank)
os.environ["WORLD_SIZE"] = str(self.server_args.num_gpus)
# Initialize the distributed environment
maybe_init_distributed_environment_and_model_parallel(
tp_size=self.server_args.tp_size,
enable_cfg_parallel=self.server_args.enable_cfg_parallel,
ulysses_degree=self.server_args.ulysses_degree,
ring_degree=self.server_args.ring_degree,
sp_size=self.server_args.sp_degree,
dp_size=self.server_args.dp_size,
)
self.pipeline = build_pipeline(self.server_args)
logger.info(
f"Worker {self.rank}: Initialized device, model, and distributed environment."
)
def execute_forward(self, batch: List[Req], server_args: ServerArgs) -> OutputBatch:
"""
Execute a forward pass.
"""
assert self.pipeline is not None
        # TODO: only handling the first req for now
req = batch[0]
output_batch = self.pipeline.forward(req, server_args)
if req.perf_logger:
req.perf_logger.log_total_duration("total_inference_time")
return output_batch
def set_lora_adapter(
self, lora_nickname: str, lora_path: str | None = None
) -> None:
"""
Set the LoRA adapter for the pipeline.
"""
assert self.pipeline is not None
self.pipeline.set_lora_adapter(lora_nickname, lora_path)
def merge_lora_weights(self) -> None:
"""
Merge LoRA weights.
"""
assert self.pipeline is not None
self.pipeline.merge_lora_weights()
def unmerge_lora_weights(self) -> None:
"""
Unmerge LoRA weights.
"""
assert self.pipeline is not None
self.pipeline.unmerge_lora_weights()
def run_scheduler_process(
local_rank: int,
rank: int,
master_port: int,
server_args: ServerArgs,
pipe_writer: mp.connection.Connection,
# For all workers: pipe to receive tasks from rank 0
task_pipe_r: mp.connection.Connection,
# For slave workers: pipe to send results back to rank 0
result_pipe_w: mp.connection.Connection | None,
# For rank 0 worker only: pipes to send tasks to slaves
task_pipes_to_slaves: list[mp.connection.Connection] | None = None,
# For rank 0 worker only: pipes to receive results from slaves
result_pipes_from_slaves: list[mp.connection.Connection] | None = None,
) -> None:
"""
The entry point for the worker process.
Rank 0 acts as the master, handling ZMQ requests and coordinating slaves.
Ranks > 0 act as slaves, waiting for tasks from the master.
"""
configure_logger(server_args)
suppress_other_loggers()
set_cuda_arch()
port_args = PortArgs.from_server_args(server_args)
# start the scheduler event loop
assert task_pipes_to_slaves is not None
assert result_pipes_from_slaves is not None
from sglang.multimodal_gen.runtime.managers.scheduler import Scheduler
scheduler = Scheduler(
server_args,
gpu_id=rank,
port_args=port_args,
task_pipes_to_slaves=task_pipes_to_slaves,
result_pipes_from_slaves=result_pipes_from_slaves,
)
logger.info(f"Worker {rank}: Scheduler loop started.")
pipe_writer.send(
{
"status": "ready",
}
)
scheduler.event_loop()
logger.info(f"Worker {rank}: Shutdown complete.")
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import zmq
from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch
from sglang.multimodal_gen.runtime.server_args import (
PortArgs,
ServerArgs,
set_global_server_args,
)
from sglang.multimodal_gen.runtime.utils.common import get_zmq_socket
from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class Scheduler:
"""
Runs the main event loop for the rank 0 worker.
It listens for external requests via ZMQ and coordinates with other workers.
This class does NOT manage worker processes.
"""
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
port_args: PortArgs,
        task_pipes_to_slaves: list | None = None,
        result_pipes_from_slaves: list | None = None,
):
self.server_args = server_args
self.port_args = port_args
set_global_server_args(server_args=server_args)
# Inter-process Communication
self.context = zmq.Context(io_threads=2)
endpoint = server_args.scheduler_endpoint()
logger.info(f"Scheduler listening at endpoint: {endpoint}")
if gpu_id == 0:
self.receiver = get_zmq_socket(self.context, zmq.REP, endpoint, True)
else:
self.receiver = None
worker = GPUWorker(
local_rank=gpu_id,
master_port=port_args.master_port,
rank=gpu_id,
server_args=server_args,
)
self.worker = worker
self.task_pipes_to_slaves = task_pipes_to_slaves
self.result_pipes_from_slaves = result_pipes_from_slaves
self.gpu_id = gpu_id
self._running = True
def return_result(self, output_batch: OutputBatch):
"""
        Reply to the client; only runs on rank 0.
"""
if self.receiver is not None:
self.receiver.send_pyobj(output_batch)
def recv_reqs(self):
"""
        For non-main schedulers, reqs are broadcast from the main rank using broadcast_pyobj.
"""
if self.receiver is not None:
recv_reqs = self.receiver.recv_pyobj()
assert isinstance(recv_reqs, list)
else:
recv_reqs = None
# TODO: fix this condition
if self.server_args.sp_degree != 1:
recv_reqs = broadcast_pyobj(
recv_reqs,
self.worker.sp_group.rank,
self.worker.sp_cpu_group,
src=self.worker.sp_group.ranks[0],
)
if self.server_args.enable_cfg_parallel:
recv_reqs = broadcast_pyobj(
recv_reqs,
self.worker.cfg_group.rank,
self.worker.cfg_cpu_group,
src=self.worker.cfg_group.ranks[0],
)
if self.server_args.tp_size > 1:
recv_reqs = broadcast_pyobj(
recv_reqs,
self.worker.tp_group.rank,
self.worker.tp_cpu_group,
src=self.worker.tp_group.ranks[0],
)
assert recv_reqs is not None
return recv_reqs
# TODO: queueing, cancellation
def event_loop(self) -> None:
"""
The main event loop that listens for ZMQ requests.
        Handles request abortion.
"""
logger.info(
f"Rank 0 scheduler listening on tcp://*:{self.server_args.scheduler_port}"
)
while self._running:
reqs = None
# 1: receive requests
try:
reqs = self.recv_reqs()
except Exception as e:
logger.error(
f"Error receiving requests in scheduler event loop: {e}",
exc_info=True,
)
continue
# 2: execute, make sure a reply is always sent
try:
output_batch = self.worker.execute_forward(reqs, self.server_args)
except Exception as e:
logger.error(
f"Error executing forward in scheduler event loop: {e}",
exc_info=True,
)
output_batch = OutputBatch(error=str(e))
try:
self.return_result(output_batch)
except zmq.ZMQError as e:
# Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
continue
logger.info("Scheduler event loop terminated.")
if self.receiver is not None:
self.receiver.close()
self.context.term()
def _broadcast_task(self, payload: dict[str, Any]) -> None:
"""Broadcast a task to all slave worker processes."""
method = payload["method"]
kwargs = {k: v for k, v in payload.items() if k != "method"}
task = {"method": method, "kwargs": kwargs}
for pipe in self.task_pipes_to_slaves:
pipe.send(task)
def _execute_on_rank0(self, payload: dict[str, Any]) -> dict[str, Any]:
"""Execute task locally on the rank 0 worker."""
method = payload["method"]
kwargs = {k: v for k, v in payload.items() if k != "method"}
handler = getattr(self.worker, method, None)
if handler:
result = handler(**kwargs)
return {"status": "ok", "result": result}
return {"status": "error", "error": f"Unknown method: {method}"}
def _collect_slave_results(self) -> list[dict[str, Any]]:
"""Collect results from all slave worker processes."""
results = []
for pipe in self.result_pipes_from_slaves:
results.append(pipe.recv())
return results
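# Example (illustrative): control messages are flat dicts where "method" names
# a GPUWorker method and the remaining keys are its keyword arguments. The
# values below are hypothetical.
#
#   payload = {
#       "method": "set_lora_adapter",
#       "lora_nickname": "style_a",
#       "lora_path": "/weights/style_a.safetensors",
#   }
#   # client:  socket.send_pyobj(payload); socket.recv_pyobj()
#   # rank 0:  scheduler._execute_on_rank0(payload) and/or scheduler._broadcast_task(payload)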
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from abc import ABC
from typing import TypeVar
import zmq
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.utils import init_logger
logger = init_logger(__name__)
_R = TypeVar("_R")
class SchedulerBase(ABC):
"""
Abstract base class for all schedulers.
"""
def __init__(self, server_args: "ServerArgs"):
"""
Initialize the scheduler.
Args:
server_args: The inference arguments
"""
self.server_args = server_args
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(self.server_args.scheduler_endpoint())
@classmethod
def get_class(cls, server_args: "ServerArgs") -> type["SchedulerBase"]:
"""
Get the scheduler class based on the server arguments.
"""
if server_args.distributed_executor_backend == "mp":
from sglang.multimodal_gen.runtime.managers.scheduler import Scheduler
# For now, always return the new Scheduler
return Scheduler
else:
raise ValueError(
f"Unsupported distributed executor backend: {server_args.distributed_executor_backend}"
)
# @abstractmethod
def start(self) -> None:
"""
Start the scheduler service.
"""
raise NotImplementedError
def execute_forward(self, batch: Req, server_args: "ServerArgs") -> OutputBatch:
"""
Execute a forward pass. This method now sends a request over ZMQ.
"""
payload = {"method": "execute_forward", "batch": batch}
self.socket.send_pyobj(payload)
output_batch = self.socket.recv_pyobj()
return output_batch
def set_lora_adapter(
self, lora_nickname: str, lora_path: str | None = None
) -> None:
"""
Set the LoRA adapter.
"""
payload = {
"method": "set_lora_adapter",
"lora_nickname": lora_nickname,
"lora_path": lora_path,
}
self.socket.send_pyobj(payload)
self.socket.recv_pyobj() # Wait for confirmation
# @abstractmethod
def unmerge_lora_weights(self) -> None:
"""
Unmerge the LoRA weights for the workers.
"""
raise NotImplementedError
# @abstractmethod
def merge_lora_weights(self) -> None:
"""
Merge the LoRA weights for the workers.
"""
raise NotImplementedError
def shutdown(self) -> None:
"""
Shutdown the scheduler.
"""
logger.info("Shutting down scheduler client.")
payload = {"method": "shutdown"}
self.socket.send_pyobj(payload)
self.socket.recv_pyobj() # Wait for shutdown confirmation
self.socket.close()
self.context.term()
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch import nn
from sglang.multimodal_gen.configs.models import DiTConfig
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
# TODO
class BaseDiT(nn.Module, ABC):
_fsdp_shard_conditions: list = []
_compile_conditions: list = []
param_names_mapping: dict
reverse_param_names_mapping: dict
hidden_size: int
num_attention_heads: int
num_channels_latents: int
# always supports torch_sdpa
_supported_attention_backends: set[AttentionBackendEnum] = (
DiTConfig()._supported_attention_backends
)
def __init_subclass__(cls) -> None:
required_class_attrs = [
"_fsdp_shard_conditions",
"param_names_mapping",
"_compile_conditions",
]
super().__init_subclass__()
for attr in required_class_attrs:
if not hasattr(cls, attr):
raise AttributeError(
f"Subclasses of BaseDiT must define '{attr}' class variable"
)
def __init__(self, config: DiTConfig, hf_config: dict[str, Any], **kwargs) -> None:
super().__init__()
self.config = config
self.hf_config = hf_config
if not self.supported_attention_backends:
raise ValueError(
f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
)
@abstractmethod
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | list[torch.Tensor],
timestep: torch.LongTensor,
encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
guidance=None,
**kwargs,
) -> torch.Tensor:
pass
def __post_init__(self) -> None:
required_attrs = ["hidden_size", "num_attention_heads", "num_channels_latents"]
for attr in required_attrs:
if not hasattr(self, attr):
raise AttributeError(
f"Subclasses of BaseDiT must define '{attr}' instance variable"
)
@property
def supported_attention_backends(self) -> set[AttentionBackendEnum]:
return self._supported_attention_backends
@property
def device(self) -> torch.device:
"""Get the device of the model."""
return next(self.parameters()).device
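# Example (illustrative skeleton, not an actual model in this package): a
# concrete subclass must provide the class attributes checked in
# __init_subclass__ and define the instance attributes checked at init time.
#
#   class MyDiT(BaseDiT):
#       _fsdp_shard_conditions = [lambda name, module: name.startswith("blocks.")]
#       _compile_conditions: list = []
#       param_names_mapping: dict = {}
#
#       def __init__(self, config: DiTConfig, hf_config: dict[str, Any]):
#           super().__init__(config, hf_config)
#           self.hidden_size = 1152
#           self.num_attention_heads = 16
#           self.num_channels_latents = 16
#
#       def forward(self, hidden_states, encoder_hidden_states, timestep, **kwargs):
#           ...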
class CachableDiT(BaseDiT):
"""
An intermediate base class that adds TeaCache optimization functionality to DiT models.
TeaCache accelerates inference by selectively skipping redundant computation when consecutive
diffusion steps are similar enough.
"""
# These are required class attributes that should be overridden by concrete implementations
_fsdp_shard_conditions = []
param_names_mapping = {}
reverse_param_names_mapping = {}
lora_param_names_mapping: dict = {}
# Ensure these instance attributes are properly defined in subclasses
hidden_size: int
num_attention_heads: int
num_channels_latents: int
# always supports torch_sdpa
_supported_attention_backends: set[AttentionBackendEnum] = (
DiTConfig()._supported_attention_backends
)
def __init__(self, config: DiTConfig, **kwargs) -> None:
super().__init__(config, **kwargs)
self.cnt = 0
self.teacache_thresh = 0
self.coefficients: list[float] = []
# NOTE(will): Only wan2.1 needs these, so we are hardcoding it here
if self.config.prefix == "wan":
self.use_ret_steps = self.config.cache_config.use_ret_steps
self.is_even = False
self.previous_residual_even: torch.Tensor | None = None
self.previous_residual_odd: torch.Tensor | None = None
self.accumulated_rel_l1_distance_even = 0
self.accumulated_rel_l1_distance_odd = 0
self.should_calc_even = True
self.should_calc_odd = True
else:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
            self.previous_residual = None
self.previous_e0_even: torch.Tensor | None = None
self.previous_e0_odd: torch.Tensor | None = None
def maybe_cache_states(
self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor
) -> None:
pass
def should_skip_forward_for_cached_states(self, **kwargs: dict[str, Any]) -> bool:
return False
def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError("maybe_retrieve_cached_states is not implemented")