Unverified Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Any
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import (
BlockMask,
create_block_mask,
flex_attention,
)
# The Wan 1.3B model has an unusual channel/head configuration and requires
# max-autotune for FlexAttention to work; see https://github.com/pytorch/pytorch/issues/133254.
# Change the compile mode back to the default for other models.
flex_attention = torch.compile(
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)
import torch.distributed as dist
from sglang.multimodal_gen.configs.models.dits import WanVideoConfig
from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size
from sglang.multimodal_gen.runtime.layers.attention import LocalAttention
from sglang.multimodal_gen.runtime.layers.layernorm import (
FP32LayerNorm,
LayerNormScaleShift,
RMSNorm,
ScaleResidual,
ScaleResidualLayerNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
_apply_rotary_emb,
get_rotary_pos_embed,
)
from sglang.multimodal_gen.runtime.layers.visual_embedding import PatchEmbed
from sglang.multimodal_gen.runtime.models.dits.base import BaseDiT
from sglang.multimodal_gen.runtime.models.dits.wanvideo import (
WanT2VCrossAttention,
WanTimeTextImageEmbedding,
)
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class CausalWanSelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
local_attn_size: int = -1,
sink_size: int = 0,
qk_norm=True,
eps=1e-6,
parallel_attention=False,
) -> None:
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.local_attn_size = local_attn_size
self.sink_size = sink_size
self.qk_norm = qk_norm
self.eps = eps
self.parallel_attention = parallel_attention
self.max_attention_size = (
32760 if local_attn_size == -1 else local_attn_size * 1560
)
# Scaled dot product attention
self.attn = LocalAttention(
num_heads=num_heads,
head_size=self.head_dim,
dropout_rate=0,
softmax_scale=None,
causal=False,
supported_attention_backends=(
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
),
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor],
block_mask: BlockMask,
kv_cache: dict | None = None,
current_start: int = 0,
cache_start: int | None = None,
):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
if cache_start is None:
cache_start = current_start
cos, sin = freqs_cis
roped_query = _apply_rotary_emb(q, cos, sin, is_neox_style=False).type_as(v)
roped_key = _apply_rotary_emb(k, cos, sin, is_neox_style=False).type_as(v)
if kv_cache is None:
# Padding for flex attention
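# The block mask built by create_block_mask assumes 128-token blocks, so q/k/v
# are right-padded with zeros up to the next multiple of 128 and the padding is
# dropped again after the attention call.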
padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
padded_roped_query = torch.cat(
[
roped_query,
torch.zeros(
[q.shape[0], padded_length, q.shape[2], q.shape[3]],
device=q.device,
dtype=v.dtype,
),
],
dim=1,
)
padded_roped_key = torch.cat(
[
roped_key,
torch.zeros(
[k.shape[0], padded_length, k.shape[2], k.shape[3]],
device=k.device,
dtype=v.dtype,
),
],
dim=1,
)
padded_v = torch.cat(
[
v,
torch.zeros(
[v.shape[0], padded_length, v.shape[2], v.shape[3]],
device=v.device,
dtype=v.dtype,
),
],
dim=1,
)
x = flex_attention(
query=padded_roped_query.transpose(2, 1),
key=padded_roped_key.transpose(2, 1),
value=padded_v.transpose(2, 1),
block_mask=block_mask,
)[:, :, : q.shape[1]].transpose(2, 1)  # drop the right padding (safe even when padded_length == 0)
else:
frame_seqlen = q.shape[1]
current_end = current_start + roped_query.shape[1]
sink_tokens = self.sink_size * frame_seqlen
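# The first `sink_size` frames act as attention sinks: their cache entries are
# never evicted, and only the tokens after them take part in the rolling-window
# eviction below.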
# If local attention is enabled and the new tokens would overflow the KV cache,
# roll the cache to make room before inserting them.
kv_cache_size = kv_cache["k"].shape[1]
num_new_tokens = roped_query.shape[1]
if (
self.local_attn_size != -1
and (current_end > kv_cache["global_end_index"].item())
and (
num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size
)
):
# Compute how many of the oldest non-sink tokens must be evicted so the new
# tokens fit, then shift the remaining cache content left.
# Clone the source slice to avoid overlapping-memory errors.
num_evicted_tokens = (
num_new_tokens + kv_cache["local_end_index"].item() - kv_cache_size
)
num_rolled_tokens = (
kv_cache["local_end_index"].item()
- num_evicted_tokens
- sink_tokens
)
kv_cache["k"][
:, sink_tokens : sink_tokens + num_rolled_tokens
] = kv_cache["k"][
:,
sink_tokens
+ num_evicted_tokens : sink_tokens
+ num_evicted_tokens
+ num_rolled_tokens,
].clone()
kv_cache["v"][
:, sink_tokens : sink_tokens + num_rolled_tokens
] = kv_cache["v"][
:,
sink_tokens
+ num_evicted_tokens : sink_tokens
+ num_evicted_tokens
+ num_rolled_tokens,
].clone()
# Insert the new keys/values at the end
local_end_index = (
kv_cache["local_end_index"].item()
+ current_end
- kv_cache["global_end_index"].item()
- num_evicted_tokens
)
local_start_index = local_end_index - num_new_tokens
kv_cache["k"][:, local_start_index:local_end_index] = roped_key
kv_cache["v"][:, local_start_index:local_end_index] = v
else:
# Assign new keys/values directly up to current_end
local_end_index = (
kv_cache["local_end_index"].item()
+ current_end
- kv_cache["global_end_index"].item()
)
local_start_index = local_end_index - num_new_tokens
kv_cache["k"] = kv_cache["k"].detach()
kv_cache["v"] = kv_cache["v"].detach()
# logger.info("kv_cache['k'] is in comp graph: %s", kv_cache["k"].requires_grad or kv_cache["k"].grad_fn is not None)
kv_cache["k"][:, local_start_index:local_end_index] = roped_key
kv_cache["v"][:, local_start_index:local_end_index] = v
x = self.attn(
roped_query,
kv_cache["k"][
:,
max(0, local_end_index - self.max_attention_size) : local_end_index,
],
kv_cache["v"][
:,
max(0, local_end_index - self.max_attention_size) : local_end_index,
],
)
kv_cache["global_end_index"].fill_(current_end)
kv_cache["local_end_index"].fill_(local_end_index)
return x
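# A sketch of the kv_cache entry consumed by CausalWanSelfAttention.forward
# (illustrative only; the cache is allocated elsewhere in the pipeline and the
# exact buffer size is an assumption here):
#
#   kv_cache = {
#       "k": torch.zeros(batch, max_attention_size, num_heads, head_dim),
#       "v": torch.zeros(batch, max_attention_size, num_heads, head_dim),
#       "global_end_index": torch.tensor([0]),  # end of all generated tokens so far
#       "local_end_index": torch.tensor([0]),   # current write position inside the cache
#   }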
class CausalWanTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
local_attn_size: int = -1,
sink_size: int = 0,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: int | None = None,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
prefix: str = "",
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.to_q = ReplicatedLinear(dim, dim, bias=True)
self.to_k = ReplicatedLinear(dim, dim, bias=True)
self.to_v = ReplicatedLinear(dim, dim, bias=True)
self.to_out = ReplicatedLinear(dim, dim, bias=True)
self.attn1 = CausalWanSelfAttention(
dim,
num_heads,
local_attn_size=local_attn_size,
sink_size=sink_size,
qk_norm=qk_norm,
eps=eps,
)
self.hidden_dim = dim
self.num_attention_heads = num_heads
self.local_attn_size = local_attn_size
dim_head = dim // num_heads
if qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
else:
print("QK Norm type not supported")
raise Exception
assert cross_attn_norm is True
self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=True,
dtype=torch.float32,
compute_dtype=torch.float32,
)
# 2. Cross-attention
# Only T2V for now
self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, eps=eps)
self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)
# 3. Feed-forward
self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh")
self.mlp_residual = ScaleResidual()
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor],
block_mask: BlockMask,
kv_cache: dict | None = None,
crossattn_cache: dict | None = None,
current_start: int = 0,
cache_start: int | None = None,
) -> torch.Tensor:
# hidden_states.shape: [batch_size, seq_length, inner_dim]
# temb.shape: [batch_size, num_frames, 6, inner_dim]
if hidden_states.dim() == 4:
hidden_states = hidden_states.squeeze(1)
num_frames = temb.shape[1]
frame_seqlen = hidden_states.shape[1] // num_frames
bs, seq_length, _ = hidden_states.shape
orig_dtype = hidden_states.dtype
# assert orig_dtype != torch.float32
e = self.scale_shift_table + temb.float()
# e.shape: [batch_size, num_frames, 6, inner_dim]
assert e.shape == (bs, num_frames, 6, self.hidden_dim)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
6, dim=2
)
# *_msa.shape: [batch_size, num_frames, 1, inner_dim]
assert shift_msa.dtype == torch.float32
# 1. Self-attention
norm_hidden_states = (
(
self.norm1(hidden_states.float()).unflatten(
dim=1, sizes=(num_frames, frame_seqlen)
)
* (1 + scale_msa)
+ shift_msa
)
.flatten(1, 2)
.to(orig_dtype)
)
query, _ = self.to_q(norm_hidden_states)
key, _ = self.to_k(norm_hidden_states)
value, _ = self.to_v(norm_hidden_states)
if self.norm_q is not None:
query = self.norm_q(query)
if self.norm_k is not None:
key = self.norm_k(key)
query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
attn_output = self.attn1(
query,
key,
value,
freqs_cis,
block_mask,
kv_cache,
current_start,
cache_start,
)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
attn_output = attn_output.squeeze(1)
null_shift = null_scale = torch.zeros(
(1,), device=hidden_states.device, dtype=hidden_states.dtype
)
norm_hidden_states, hidden_states = self.self_attn_residual_norm(
hidden_states, attn_output, gate_msa, null_shift, null_scale
)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype
), hidden_states.to(orig_dtype)
# 2. Cross-attention
attn_output = self.attn2(
norm_hidden_states,
context=encoder_hidden_states,
context_lens=None,
crossattn_cache=crossattn_cache,
)
norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
hidden_states, attn_output, 1, c_shift_msa, c_scale_msa
)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype
), hidden_states.to(orig_dtype)
# 3. Feed-forward
ff_output = self.ffn(norm_hidden_states)
hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa)
hidden_states = hidden_states.to(orig_dtype)
return hidden_states
class CausalWanTransformer3DModel(BaseDiT):
_fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions
_compile_conditions = WanVideoConfig()._compile_conditions
_supported_attention_backends = WanVideoConfig()._supported_attention_backends
param_names_mapping = WanVideoConfig().param_names_mapping
reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping
lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping
def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None:
super().__init__(config=config, hf_config=hf_config)
inner_dim = config.num_attention_heads * config.attention_head_dim
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.attention_head_dim = config.attention_head_dim
self.in_channels = config.in_channels
self.out_channels = config.out_channels
self.num_channels_latents = config.num_channels_latents
self.patch_size = config.patch_size
self.text_len = config.text_len
self.local_attn_size = config.local_attn_size
# 1. Patch & position embedding
self.patch_embedding = PatchEmbed(
in_chans=config.in_channels,
embed_dim=inner_dim,
patch_size=config.patch_size,
flatten=False,
)
# 2. Condition embeddings
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=config.freq_dim,
text_embed_dim=config.text_dim,
image_embed_dim=config.image_dim,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList(
[
CausalWanTransformerBlock(
inner_dim,
config.ffn_dim,
config.num_attention_heads,
config.local_attn_size,
config.sink_size,
config.qk_norm,
config.cross_attn_norm,
config.eps,
config.added_kv_proj_dim,
self._supported_attention_backends,
prefix=f"{config.prefix}.blocks.{i}",
)
for i in range(config.num_layers)
]
)
# 4. Output norm & projection
self.norm_out = LayerNormScaleShift(
inner_dim,
norm_type="layer",
eps=config.eps,
elementwise_affine=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)
self.proj_out = nn.Linear(
inner_dim, config.out_channels * math.prod(config.patch_size)
)
self.scale_shift_table = nn.Parameter(
torch.randn(1, 2, inner_dim) / inner_dim**0.5
)
self.gradient_checkpointing = False
# Causal-specific
self.block_mask = None
self.num_frame_per_block = config.arch_config.num_frames_per_block
assert self.num_frame_per_block <= 3
self.independent_first_frame = False
self.__post_init__()
@staticmethod
def _prepare_blockwise_causal_attn_mask(
device: torch.device | str,
num_frames: int = 21,
frame_seqlen: int = 1560,
num_frame_per_block=1,
local_attn_size=-1,
) -> BlockMask:
"""
we will divide the token sequence into the following format
[1 latent frame] [1 latent frame] ... [1 latent frame]
We use flexattention to construct the attention mask
"""
total_length = num_frames * frame_seqlen
# Right-pad the total sequence length to a multiple of 128 (the flex-attention block size)
padded_length = math.ceil(total_length / 128) * 128 - total_length
ends = torch.zeros(
total_length + padded_length, device=device, dtype=torch.long
)
# The block-wise causal mask lets each token attend to every token up to the end of its own frame chunk
frame_indices = torch.arange(
start=0,
end=total_length,
step=frame_seqlen * num_frame_per_block,
device=device,
)
for tmp in frame_indices:
ends[tmp : tmp + frame_seqlen * num_frame_per_block] = (
tmp + frame_seqlen * num_frame_per_block
)
def attention_mask(b, h, q_idx, kv_idx):
if local_attn_size == -1:
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
else:
return (
(kv_idx < ends[q_idx])
& (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))
) | (q_idx == kv_idx)
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
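# Resulting pattern (ignoring the padded tail): with num_frame_per_block = 2,
# tokens in frames {0, 1} attend only to frames {0, 1}, tokens in frames {2, 3}
# attend to frames {0..3}, and so on. With local_attn_size = k, each query
# additionally sees at most the last k frames ending at its chunk boundary.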
block_mask = create_block_mask(
attention_mask,
B=None,
H=None,
Q_LEN=total_length + padded_length,
KV_LEN=total_length + padded_length,
_compile=False,
device=device,
)
if not dist.is_initialized() or dist.get_rank() == 0:
print(
f"Caching a block-wise causal attention mask with a block size of {num_frame_per_block} frame(s)"
)
print(block_mask)
# import imageio
# import numpy as np
# from torch.nn.attention.flex_attention import create_mask
# mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length +
# padded_length, KV_LEN=total_length + padded_length, device=device)
# import cv2
# mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024))
# imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. * mask))
return block_mask
def _forward_inference(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | list[torch.Tensor],
timestep: torch.LongTensor,
encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
kv_cache: dict | None = None,
crossattn_cache: dict | None = None,
current_start: int = 0,
cache_start: int = 0,
start_frame: int = 0,
**kwargs,
) -> torch.Tensor:
r"""
Run the diffusion model with kv caching.
See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details.
This function will be run for num_frame times.
Process the latent frames one by one (1560 tokens each)
"""
orig_dtype = hidden_states.dtype
if not isinstance(encoder_hidden_states, torch.Tensor):
encoder_hidden_states = encoder_hidden_states[0]
if (
isinstance(encoder_hidden_states_image, list)
and len(encoder_hidden_states_image) > 0
):
encoder_hidden_states_image = encoder_hidden_states_image[0]
else:
encoder_hidden_states_image = None
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
# Get rotary embeddings
d = self.hidden_size // self.num_attention_heads
rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
freqs_cos, freqs_sin = get_rotary_pos_embed(
(
post_patch_num_frames * get_sp_world_size(),
post_patch_height,
post_patch_width,
),
self.hidden_size,
self.num_attention_heads,
rope_dim_list,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
rope_theta=10000,
start_frame=start_frame, # Assume that start_frame is 0 when kv_cache is None
)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
freqs_cis = (
(freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None
)
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (
self.condition_embedder(
timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image
)
)
timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten(
dim=0, sizes=timestep.shape
)
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat(
[encoder_hidden_states_image, encoder_hidden_states], dim=1
)
encoder_hidden_states = (
encoder_hidden_states.to(orig_dtype)
if current_platform.is_mps()
else encoder_hidden_states
) # cast to orig_dtype for MPS
assert encoder_hidden_states.dtype == orig_dtype
# 4. Transformer blocks
for block_index, block in enumerate(self.blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
causal_kwargs = {
"kv_cache": kv_cache[block_index],
"current_start": current_start,
"cache_start": cache_start,
"block_mask": self.block_mask,
}
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
timestep_proj,
freqs_cis,
**causal_kwargs,
)
else:
causal_kwargs = {
"kv_cache": kv_cache[block_index],
"crossattn_cache": crossattn_cache[block_index],
"current_start": current_start,
"cache_start": cache_start,
"block_mask": self.block_mask,
}
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep_proj,
freqs_cis,
**causal_kwargs,
)
# 5. Output norm, projection & unpatchify
temb = temb.unflatten(dim=0, sizes=timestep.shape).unsqueeze(2)
shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2)
hidden_states = self.norm_out(hidden_states, shift, scale)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size,
post_patch_num_frames,
post_patch_height,
post_patch_width,
p_t,
p_h,
p_w,
-1,
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
return output
def _forward_train(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | list[torch.Tensor],
timestep: torch.LongTensor,
encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
start_frame: int = 0,
**kwargs,
) -> torch.Tensor:
orig_dtype = hidden_states.dtype
if not isinstance(encoder_hidden_states, torch.Tensor):
encoder_hidden_states = encoder_hidden_states[0]
if (
isinstance(encoder_hidden_states_image, list)
and len(encoder_hidden_states_image) > 0
):
encoder_hidden_states_image = encoder_hidden_states_image[0]
else:
encoder_hidden_states_image = None
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
# Get rotary embeddings
d = self.hidden_size // self.num_attention_heads
rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
freqs_cos, freqs_sin = get_rotary_pos_embed(
(
post_patch_num_frames * get_sp_world_size(),
post_patch_height,
post_patch_width,
),
self.hidden_size,
self.num_attention_heads,
rope_dim_list,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
rope_theta=10000,
start_frame=start_frame,
)
freqs_cos = freqs_cos.to(hidden_states.device)
freqs_sin = freqs_sin.to(hidden_states.device)
freqs_cis = (
(freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None
)
# Construct blockwise causal attn mask
if self.block_mask is None:
self.block_mask = self._prepare_blockwise_causal_attn_mask(
device=hidden_states.device,
num_frames=num_frames,
frame_seqlen=post_patch_height * post_patch_width,
num_frame_per_block=self.num_frame_per_block,
local_attn_size=self.local_attn_size,
)
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (
self.condition_embedder(
timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image
)
)
timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten(
dim=0, sizes=timestep.shape
)
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat(
[encoder_hidden_states_image, encoder_hidden_states], dim=1
)
encoder_hidden_states = (
encoder_hidden_states.to(orig_dtype)
if current_platform.is_mps()
else encoder_hidden_states
) # cast to orig_dtype for MPS
assert encoder_hidden_states.dtype == orig_dtype
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
timestep_proj,
freqs_cis,
block_mask=self.block_mask,
)
else:
for block in self.blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep_proj,
freqs_cis,
block_mask=self.block_mask,
)
# 5. Output norm, projection & unpatchify
temb = temb.unflatten(dim=0, sizes=timestep.shape).unsqueeze(2)
shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2)
hidden_states = self.norm_out(hidden_states, shift, scale)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size,
post_patch_num_frames,
post_patch_height,
post_patch_width,
p_t,
p_h,
p_w,
-1,
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
return output
def forward(self, *args, **kwargs):
if kwargs.get("kv_cache") is not None:
return self._forward_inference(*args, **kwargs)
else:
return self._forward_train(*args, **kwargs)
EntryClass = CausalWanTransformer3DModel
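# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not used by the pipeline): builds the
# block-wise causal mask above for a tiny, assumed configuration so the
# frame-chunked attention pattern can be inspected. The real Wan setting uses
# frame_seqlen = 1560 tokens per latent frame.
if __name__ == "__main__":
    _demo_device = "cuda" if torch.cuda.is_available() else "cpu"
    _demo_mask = CausalWanTransformer3DModel._prepare_blockwise_causal_attn_mask(
        device=_demo_device,
        num_frames=4,
        frame_seqlen=8,
        num_frame_per_block=2,
        local_attn_size=-1,
    )
    # Each query block attends only to key blocks up to the end of its own
    # frame chunk, which is what makes generation causal across frames.
    print(_demo_mask)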
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.models.attention import AttentionModuleMixin, FeedForward
from diffusers.models.embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import (
AdaLayerNormContinuous,
AdaLayerNormZero,
AdaLayerNormZeroSingle,
)
from torch.nn import LayerNorm as LayerNorm
from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
from sglang.multimodal_gen.runtime.layers.attention import LocalAttention
# from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm as LayerNorm
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
NDRotaryEmbedding,
_apply_rotary_emb,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__) # pylint: disable=invalid-name
def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
query, _ = attn.to_q(hidden_states)
key, _ = attn.to_k(hidden_states)
value, _ = attn.to_v(hidden_states)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
encoder_query, _ = attn.add_q_proj(encoder_hidden_states)
encoder_key, _ = attn.add_k_proj(encoder_hidden_states)
encoder_value, _ = attn.add_v_proj(encoder_hidden_states)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_fused_projections(
attn: "FluxAttention", hidden_states, encoder_hidden_states=None
):
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
encoder_query = encoder_key = encoder_value = None
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(
encoder_hidden_states
).chunk(3, dim=-1)
return query, key, value, encoder_query, encoder_key, encoder_value
def _get_qkv_projections(
attn: "FluxAttention", hidden_states, encoder_hidden_states=None
):
if attn.fused_projections:
return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
return _get_projections(attn, hidden_states, encoder_hidden_states)
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
def __init__(
self,
query_dim: int,
num_heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: Optional[int] = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else num_heads
self.added_kv_proj_dim = added_kv_proj_dim
self.added_proj_bias = added_proj_bias
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
self.to_q = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
self.to_k = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
self.to_v = ReplicatedLinear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(
ReplicatedLinear(self.inner_dim, self.out_dim, bias=out_bias)
)
if dropout != 0.0:
self.to_out.append(torch.nn.Dropout(dropout))
if added_kv_proj_dim is not None:
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
self.add_q_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.add_k_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.add_v_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=added_proj_bias
)
self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias)
# Scaled dot product attention
self.attn = LocalAttention(
num_heads=num_heads,
head_size=self.head_dim,
dropout_rate=0,
softmax_scale=None,
causal=False,
supported_attention_backends=(
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.SAGE_ATTN,
),
)
def forward(
self,
x: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
freqs_cis=None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
query, key, value, encoder_query, encoder_key, encoder_value = (
_get_qkv_projections(self, x, encoder_hidden_states)
)
query = query.unflatten(-1, (self.heads, -1))
key = key.unflatten(-1, (self.heads, -1))
value = value.unflatten(-1, (self.heads, -1))
query = self.norm_q(query)
key = self.norm_k(key)
if self.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
encoder_value = encoder_value.unflatten(-1, (self.heads, -1))
encoder_query = self.norm_added_q(encoder_query)
encoder_key = self.norm_added_k(encoder_key)
bsz, seq_len, _, _ = query.shape
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if freqs_cis is not None:
cos, sin = freqs_cis
query = _apply_rotary_emb(
query, cos, sin, is_neox_style=False, interleaved=False
)
key = _apply_rotary_emb(
key, cos, sin, is_neox_style=False, interleaved=False
)
x = self.attn(query, key, value)
x = x.flatten(2, 3)
x = x.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, x = x.split_with_sizes(
[
encoder_hidden_states.shape[1],
x.shape[1] - encoder_hidden_states.shape[1],
],
dim=1,
)
x, _ = self.to_out[0](x)
if len(self.to_out) == 2:
x = self.to_out[1](x)
encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states)
return x, encoder_hidden_states
else:
return x
class FluxSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = ReplicatedLinear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = ReplicatedLinear(dim + self.mlp_hidden_dim, dim)
self.attn = FluxAttention(
query_dim=dim,
dim_head=attention_head_dim,
num_heads=num_attention_heads,
out_dim=dim,
bias=True,
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
proj_hidden_states, _ = self.proj_mlp(norm_hidden_states)
mlp_hidden_states = self.act_mlp(proj_hidden_states)
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
x=norm_hidden_states,
freqs_cis=freqs_cis,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
proj_out, _ = self.proj_out(hidden_states)
hidden_states = gate * proj_out
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
encoder_hidden_states, hidden_states = (
hidden_states[:, :text_seq_len],
hidden_states[:, text_seq_len:],
)
return encoder_hidden_states, hidden_states
class FluxTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
self.attn = FluxAttention(
query_dim=dim,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
num_heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
eps=eps,
)
self.norm2 = LayerNorm(dim, eps=1e-6, elementwise_affine=False)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = LayerNorm(dim, eps=1e-6, elementwise_affine=False)
self.ff_context = FeedForward(
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, emb=temb
)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
self.norm1_context(encoder_hidden_states, emb=temb)
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
x=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
freqs_cis=freqs_cis,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = (
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
)
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = (
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+ c_shift_mlp[:, None]
)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = (
encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
)
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.rope = NDRotaryEmbedding(
rope_dim_list=axes_dim,
rope_theta=theta,
use_real=False,
repeat_interleave_real=False,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
)
def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
pos = ids.float()
# freqs_cos, freqs_sin = self.rope.forward(positions=pos)
freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos)
return freqs_cos.contiguous().float(), freqs_sin.contiguous().float()
class FluxTransformer2DModel(CachableDiT):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
"""
def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None:
super().__init__(config=config, hf_config=hf_config)
self.config = config.arch_config
self.out_channels = (
getattr(self.config, "out_channels", None) or self.config.in_channels
)
self.inner_dim = (
self.config.num_attention_heads * self.config.attention_head_dim
)
self.rotary_emb = FluxPosEmbed(theta=10000, axes_dim=self.config.axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings
if self.config.guidance_embeds
else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim,
pooled_projection_dim=self.config.pooled_projection_dim,
)
self.context_embedder = ReplicatedLinear(
self.config.joint_attention_dim, self.inner_dim
)
self.x_embedder = ReplicatedLinear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for _ in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for _ in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
)
self.proj_out = ReplicatedLinear(
self.inner_dim,
self.config.patch_size * self.config.patch_size * self.out_channels,
bias=True,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
guidance: torch.Tensor = None,
freqs_cis: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
guidance (`torch.Tensor`):
Guidance embeddings.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
"""
if (
joint_attention_kwargs is not None
and joint_attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states, _ = self.x_embedder(hidden_states)
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states, _ = self.context_embedder(encoder_hidden_states)
if (
joint_attention_kwargs is not None
and "ip_adapter_image_embeds" in joint_attention_kwargs
):
ip_adapter_image_embeds = joint_attention_kwargs.pop(
"ip_adapter_image_embeds"
)
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
freqs_cis=freqs_cis,
joint_attention_kwargs=joint_attention_kwargs,
)
for index_block, block in enumerate(self.single_transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
freqs_cis=freqs_cis,
joint_attention_kwargs=joint_attention_kwargs,
)
hidden_states = self.norm_out(hidden_states, temb)
output, _ = self.proj_out(hidden_states)
return output
EntryClass = FluxTransformer2DModel
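# Shape sketch for FluxTransformer2DModel.forward (illustrative, matching the
# docstring above; B = batch size):
#   hidden_states:         [B, image_seq_len, in_channels]
#   encoder_hidden_states: [B, text_seq_len, joint_attention_dim]
#   pooled_projections:    [B, pooled_projection_dim]
#   timestep:              [B]
#   output:                [B, image_seq_len, patch_size * patch_size * out_channels]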
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig
from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams
from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size
from sglang.multimodal_gen.runtime.layers.attention import (
LocalAttention,
UlyssesAttention,
)
from sglang.multimodal_gen.runtime.layers.layernorm import (
LayerNormScaleShift,
RMSNorm,
ScaleResidual,
ScaleResidualLayerNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
_apply_rotary_emb,
get_rotary_pos_embed,
)
from sglang.multimodal_gen.runtime.layers.visual_embedding import (
ModulateProjection,
PatchEmbed,
TimestepEmbedder,
unpatchify,
)
from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.models.utils import modulate
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal DiT block with separate modulation for text and image/video,
using distributed attention and linear layers.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
mlp_ratio: float,
dtype: torch.dtype | None = None,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
prefix: str = "",
):
super().__init__()
self.deterministic = False
self.num_attention_heads = num_attention_heads
head_dim = hidden_size // num_attention_heads
mlp_hidden_dim = int(hidden_size * mlp_ratio)
# Image modulation components
self.img_mod = ModulateProjection(
hidden_size,
factor=6,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.img_mod",
)
# Fused operations for image stream
self.img_attn_norm = LayerNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.img_mlp_residual = ScaleResidual()
# Image attention components
self.img_attn_qkv = ReplicatedLinear(
hidden_size,
hidden_size * 3,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.img_attn_qkv",
)
self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.img_attn_proj = ReplicatedLinear(
hidden_size,
hidden_size,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.img_attn_proj",
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
bias=True,
dtype=dtype,
prefix=f"{prefix}.img_mlp",
)
# Text modulation components
self.txt_mod = ModulateProjection(
hidden_size,
factor=6,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.txt_mod",
)
# Fused operations for text stream
self.txt_attn_norm = LayerNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype
)
self.txt_mlp_residual = ScaleResidual()
# Text attention components
self.txt_attn_qkv = ReplicatedLinear(
hidden_size, hidden_size * 3, bias=True, params_dtype=dtype
)
# QK norm layers for text
self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.txt_attn_proj = ReplicatedLinear(
hidden_size, hidden_size, bias=True, params_dtype=dtype
)
self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype)
# Use UlyssesAttention in place of distributed attention
self.attn = UlyssesAttention(
num_heads=num_attention_heads,
head_size=head_dim,
causal=False,
supported_attention_backends=supported_attention_backends,
prefix=f"{prefix}.attn",
)
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
freqs_cis: tuple,
) -> tuple[torch.Tensor, torch.Tensor]:
# Process modulation vectors
img_mod_outputs = self.img_mod(vec)
(
img_attn_shift,
img_attn_scale,
img_attn_gate,
img_mlp_shift,
img_mlp_scale,
img_mlp_gate,
) = torch.chunk(img_mod_outputs, 6, dim=-1)
txt_mod_outputs = self.txt_mod(vec)
(
txt_attn_shift,
txt_attn_scale,
txt_attn_gate,
txt_mlp_shift,
txt_mlp_scale,
txt_mlp_gate,
) = torch.chunk(txt_mod_outputs, 6, dim=-1)
# Prepare image for attention using fused operation
img_attn_input = self.img_attn_norm(img, img_attn_shift, img_attn_scale)
# Get QKV for image
img_qkv, _ = self.img_attn_qkv(img_attn_input)
batch_size, image_seq_len = img_qkv.shape[0], img_qkv.shape[1]
# Split QKV
img_qkv = img_qkv.view(
batch_size, image_seq_len, 3, self.num_attention_heads, -1
)
img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :, 2]
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q.contiguous()).to(img_v)
img_k = self.img_attn_k_norm(img_k.contiguous()).to(img_v)
# Apply rotary embeddings
cos, sin = freqs_cis
img_q, img_k = _apply_rotary_emb(
img_q, cos, sin, is_neox_style=False
), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False)
# Prepare text for attention using fused operation
txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale)
# Get QKV for text
txt_qkv, _ = self.txt_attn_qkv(txt_attn_input)
batch_size, text_seq_len = txt_qkv.shape[0], txt_qkv.shape[1]
# Split QKV
txt_qkv = txt_qkv.view(
batch_size, text_seq_len, 3, self.num_attention_heads, -1
)
txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :, 2]
# Apply QK-Norm if needed
txt_q = self.txt_attn_q_norm(txt_q.contiguous()).to(txt_q.dtype)
txt_k = self.txt_attn_k_norm(txt_k.contiguous()).to(txt_k.dtype)
# Run distributed attention
img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v)
img_attn_out, _ = self.img_attn_proj(
img_attn.view(batch_size, image_seq_len, -1)
)
# Use fused operation for residual connection, normalization, and modulation
img_mlp_input, img_residual = self.img_attn_residual_mlp_norm(
img, img_attn_out, img_attn_gate, img_mlp_shift, img_mlp_scale
)
# Process image MLP
img_mlp_out = self.img_mlp(img_mlp_input)
img = self.img_mlp_residual(img_residual, img_mlp_out, img_mlp_gate)
# Process text attention output
txt_attn_out, _ = self.txt_attn_proj(
txt_attn.reshape(batch_size, text_seq_len, -1)
)
# Use fused operation for residual connection, normalization, and modulation
txt_mlp_input, txt_residual = self.txt_attn_residual_mlp_norm(
txt, txt_attn_out, txt_attn_gate, txt_mlp_shift, txt_mlp_scale
)
# Process text MLP
txt_mlp_out = self.txt_mlp(txt_mlp_input)
txt = self.txt_mlp_residual(txt_residual, txt_mlp_out, txt_mlp_gate)
return img, txt
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers using distributed attention
and tensor parallelism.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
mlp_ratio: float = 4.0,
dtype: torch.dtype | None = None,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
prefix: str = "",
):
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
head_dim = hidden_size // num_attention_heads
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
# Combined QKV and MLP input projection
self.linear1 = ReplicatedLinear(
hidden_size,
hidden_size * 3 + mlp_hidden_dim,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.linear1",
)
# Combined projection and MLP output
self.linear2 = ReplicatedLinear(
hidden_size + mlp_hidden_dim,
hidden_size,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.linear2",
)
# QK norm layers
self.q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)
self.k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype)
# Fused operations with better naming
self.input_norm_scale_shift = LayerNormScaleShift(
hidden_size,
norm_type="layer",
eps=1e-6,
elementwise_affine=False,
dtype=dtype,
)
self.output_residual = ScaleResidual()
# Activation function
self.mlp_act = nn.GELU(approximate="tanh")
# Modulation
self.modulation = ModulateProjection(
hidden_size,
factor=3,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.modulation",
)
# Use UlyssesAttention in place of distributed attention
self.attn = UlyssesAttention(
num_heads=num_attention_heads,
head_size=head_dim,
causal=False,
supported_attention_backends=supported_attention_backends,
prefix=f"{prefix}.attn",
)
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
freqs_cis: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
# Process modulation
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
# Apply pre-norm and modulation using fused operation
x_mod = self.input_norm_scale_shift(x, mod_shift, mod_scale)
# Get combined projections
linear1_out, _ = self.linear1(x_mod)
# Split into QKV and MLP parts
qkv, mlp = torch.split(
linear1_out, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
# Process QKV
batch_size, seq_len = qkv.shape[0], qkv.shape[1]
qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# Apply QK-Norm
q = self.q_norm(q.contiguous()).to(v.dtype)
k = self.k_norm(k.contiguous()).to(v.dtype)
# Split into image and text parts
img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:]
img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:]
img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:]
# Apply rotary embeddings to image parts
cos, sin = freqs_cis
img_q, img_k = _apply_rotary_emb(
img_q, cos, sin, is_neox_style=False
), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False)
# Run distributed attention
img_attn_output, txt_attn_output = self.attn(
img_q, img_k, img_v, txt_q, txt_k, txt_v
)
attn_output = torch.cat((img_attn_output, txt_attn_output), dim=1).view(
batch_size, seq_len, -1
)
# Process MLP activation
mlp_output = self.mlp_act(mlp)
# Combine attention and MLP outputs
combined = torch.cat((attn_output, mlp_output), dim=-1)
# Final projection
output, _ = self.linear2(combined)
# Apply residual connection with gating using fused operation
return self.output_residual(x, output, mod_gate)
class HunyuanVideoTransformer3DModel(CachableDiT):
"""
HunyuanVideo Transformer backbone adapted for distributed training.
This implementation uses distributed attention and linear layers for efficient
parallel processing across multiple GPUs.
Based on the architecture from:
- Flux.1: https://github.com/black-forest-labs/flux
- MMDiT: http://arxiv.org/abs/2403.03206
"""
# PY: we make the input args the same as HF config
# shard single stream, double stream blocks, and refiner_blocks
_fsdp_shard_conditions = HunyuanVideoConfig()._fsdp_shard_conditions
_compile_conditions = HunyuanVideoConfig()._compile_conditions
_supported_attention_backends = HunyuanVideoConfig()._supported_attention_backends
param_names_mapping = HunyuanVideoConfig().param_names_mapping
reverse_param_names_mapping = HunyuanVideoConfig().reverse_param_names_mapping
lora_param_names_mapping = HunyuanVideoConfig().lora_param_names_mapping
def __init__(self, config: HunyuanVideoConfig, hf_config: dict[str, Any]):
super().__init__(config=config, hf_config=hf_config)
self.patch_size = [config.patch_size_t, config.patch_size, config.patch_size]
self.in_channels = config.in_channels
self.num_channels_latents = config.num_channels_latents
self.out_channels = (
config.in_channels if config.out_channels is None else config.out_channels
)
self.unpatchify_channels = self.out_channels
self.guidance_embeds = config.guidance_embeds
self.rope_dim_list = list(config.rope_axes_dim)
self.rope_theta = config.rope_theta
self.text_states_dim = config.text_embed_dim
self.text_states_dim_2 = config.pooled_projection_dim
# TODO(will): hack?
self.dtype = config.dtype
pe_dim = config.hidden_size // config.num_attention_heads
if sum(config.rope_axes_dim) != pe_dim:
raise ValueError(
f"Got {config.rope_axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_channels_latents = config.num_channels_latents
# Image projection
self.img_in = PatchEmbed(
self.patch_size,
self.in_channels,
self.hidden_size,
dtype=config.dtype,
prefix=f"{config.prefix}.img_in",
)
self.txt_in = SingleTokenRefiner(
self.text_states_dim,
config.hidden_size,
config.num_attention_heads,
depth=config.num_refiner_layers,
dtype=config.dtype,
prefix=f"{config.prefix}.txt_in",
)
# Time modulation
self.time_in = TimestepEmbedder(
self.hidden_size,
act_layer="silu",
dtype=config.dtype,
prefix=f"{config.prefix}.time_in",
)
# Text modulation
self.vector_in = MLP(
self.text_states_dim_2,
self.hidden_size,
self.hidden_size,
act_type="silu",
dtype=config.dtype,
prefix=f"{config.prefix}.vector_in",
)
# Guidance modulation
self.guidance_in = (
TimestepEmbedder(
self.hidden_size,
act_layer="silu",
dtype=config.dtype,
prefix=f"{config.prefix}.guidance_in",
)
if self.guidance_embeds
else None
)
# Double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
config.hidden_size,
config.num_attention_heads,
mlp_ratio=config.mlp_ratio,
dtype=config.dtype,
supported_attention_backends=self._supported_attention_backends,
prefix=f"{config.prefix}.double_blocks.{i}",
)
for i in range(config.num_layers)
]
)
# Single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
config.hidden_size,
config.num_attention_heads,
mlp_ratio=config.mlp_ratio,
dtype=config.dtype,
supported_attention_backends=self._supported_attention_backends,
prefix=f"{config.prefix}.single_blocks.{i+config.num_layers}",
)
for i in range(config.num_single_layers)
]
)
self.final_layer = FinalLayer(
config.hidden_size,
self.patch_size,
self.out_channels,
dtype=config.dtype,
prefix=f"{config.prefix}.final_layer",
)
self.__post_init__()
# TODO: change the input to the FORWARD_BATCH dict
# TODO: change output to a dict
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | list[torch.Tensor],
timestep: torch.LongTensor,
encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
guidance=None,
**kwargs,
):
"""
Forward pass of the HunyuanDiT model.
Args:
hidden_states: Input image/video latents [B, C, T, H, W]
encoder_hidden_states: Text embeddings [B, L, D]
timestep: Diffusion timestep
guidance: Guidance scale for CFG
Returns:
Tuple of (output)
"""
forward_context = get_forward_context()
forward_batch = forward_context.forward_batch
enable_teacache = forward_batch is not None and forward_batch.enable_teacache
if guidance is None:
guidance = torch.tensor(
[6016.0], device=hidden_states.device, dtype=hidden_states.dtype
)
img = x = hidden_states
t = timestep
# Split text embeddings - first token is global, rest are per-token
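# Token 0 of the text embedding carries the pooled/global representation (its
# first pooled_projection_dim channels are used for modulation); the remaining
# tokens are the per-token text states fed to the refiner.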
if isinstance(encoder_hidden_states, torch.Tensor):
txt = encoder_hidden_states[:, 1:]
text_states_2 = encoder_hidden_states[:, 0, : self.text_states_dim_2]
else:
txt = encoder_hidden_states[0]
text_states_2 = encoder_hidden_states[1]
# Get spatial dimensions
_, _, ot, oh, ow = x.shape # codespell:ignore
tt, th, tw = (
ot // self.patch_size[0], # codespell:ignore
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Get rotary embeddings
freqs_cos, freqs_sin = get_rotary_pos_embed(
(tt * get_sp_world_size(), th, tw),
self.hidden_size,
self.num_attention_heads,
self.rope_dim_list,
self.rope_theta,
)
freqs_cos = freqs_cos.to(x.device)
freqs_sin = freqs_sin.to(x.device)
# Prepare modulation vectors
vec = self.time_in(t)
# Add text modulation
vec = vec + self.vector_in(text_states_2)
# Add guidance modulation if needed
if self.guidance_in and guidance is not None:
vec = vec + self.guidance_in(guidance)
# Embed image and text
img = self.img_in(img)
txt = self.txt_in(txt, t)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
should_skip_forward = self.should_skip_forward_for_cached_states(
img=img, vec=vec
)
if should_skip_forward:
img = self.retrieve_cached_states(img)
else:
if enable_teacache:
original_img = img.clone()
# Process through double stream blocks
for index, block in enumerate(self.double_blocks):
double_block_args = [img, txt, vec, freqs_cis]
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks
x = torch.cat((img, txt), 1)
# Process through single stream blocks
if len(self.single_blocks) > 0:
for index, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
freqs_cis,
]
x = block(*single_block_args)
# Extract image features
img = x[:, :img_seq_len, ...]
if enable_teacache:
self.maybe_cache_states(img, original_img)
# Final layer processing
img = self.final_layer(img, vec)
# Unpatchify to get original shape
img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels)
return img
def maybe_cache_states(
self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor
) -> None:
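        # Cache the residual added by the transformer blocks so that a skipped step
        # can reapply it in retrieve_cached_states.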
self.previous_residual = hidden_states - original_hidden_states
def should_skip_forward_for_cached_states(self, **kwargs) -> bool:
forward_context = get_forward_context()
forward_batch = forward_context.forward_batch
if forward_batch is None:
return False
current_timestep = forward_context.current_timestep
enable_teacache = forward_batch.enable_teacache
if not enable_teacache:
return False
raise NotImplementedError("teacache is not supported yet for HunyuanVideo")
teacache_params = forward_batch.teacache_params
assert teacache_params is not None, "teacache_params is not initialized"
assert isinstance(
teacache_params, TeaCacheParams
), "teacache_params is not a TeaCacheParams"
num_inference_steps = forward_batch.num_inference_steps
teache_thresh = teacache_params.teacache_thresh
coefficients = teacache_params.coefficients
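        # TeaCache (currently unreachable; the NotImplementedError above is raised first):
        # modulate the input of the first double block, accumulate a polynomial-rescaled
        # relative L1 distance to the previous step, and skip the blocks when it stays
        # below the threshold, reusing the cached residual instead.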
if current_timestep == 0:
self.cnt = 0
inp = kwargs["img"].clone()
vec_ = kwargs["vec"].clone()
# convert to DTensor
vec_ = torch.distributed.tensor.DTensor.from_local(
vec_,
torch.distributed.DeviceMesh(
"cuda", list(range(get_sp_world_size())), mesh_dim_names=("dp",)
),
[torch.distributed.tensor.Replicate()],
)
inp = torch.distributed.tensor.DTensor.from_local(
inp,
torch.distributed.DeviceMesh(
"cuda", list(range(get_sp_world_size())), mesh_dim_names=("dp",)
),
[torch.distributed.tensor.Replicate()],
)
# txt_ = kwargs["txt"].clone()
# inp = img.clone()
# vec_ = vec.clone()
# txt_ = txt.clone()
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = (
self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
)
normed_inp = self.double_blocks[0].img_attn_norm.norm(inp)
modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale)
if self.cnt == 0 or self.cnt == num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [
7.33226126e02,
-4.01131952e02,
6.75869174e01,
-3.14987800e00,
9.61237896e-02,
]
rescale_func = np.poly1d(coefficients)
assert (
self.previous_modulated_input is not None
), "previous_modulated_input is not initialized"
self.accumulated_rel_l1_distance += rescale_func(
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
)
.cpu()
.item()
)
if self.accumulated_rel_l1_distance < teache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
return not should_calc
def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states + self.previous_residual
class SingleTokenRefiner(nn.Module):
"""
A token refiner that processes text embeddings with attention to improve
their representation for cross-attention with image features.
"""
def __init__(
self,
in_channels,
hidden_size,
num_attention_heads,
depth=2,
qkv_bias=True,
dtype=None,
prefix: str = "",
) -> None:
super().__init__()
# Input projection
self.input_embedder = ReplicatedLinear(
in_channels,
hidden_size,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.input_embedder",
)
# Timestep embedding
self.t_embedder = TimestepEmbedder(
hidden_size, act_layer="silu", dtype=dtype, prefix=f"{prefix}.t_embedder"
)
# Context embedding
self.c_embedder = MLP(
in_channels,
hidden_size,
hidden_size,
act_type="silu",
dtype=dtype,
prefix=f"{prefix}.c_embedder",
)
# Refiner blocks
self.refiner_blocks = nn.ModuleList(
[
IndividualTokenRefinerBlock(
hidden_size,
num_attention_heads,
qkv_bias=qkv_bias,
dtype=dtype,
prefix=f"{prefix}.refiner_blocks.{i}",
)
for i in range(depth)
]
)
def forward(self, x, t):
# Get timestep embeddings
timestep_aware_representations = self.t_embedder(t)
# Get context-aware representations
context_aware_representations = torch.mean(x, dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
# Project input
x, _ = self.input_embedder(x)
# Process through refiner blocks
for block in self.refiner_blocks:
x = block(x, c)
return x
class IndividualTokenRefinerBlock(nn.Module):
"""
A transformer block for refining individual tokens with self-attention.
"""
def __init__(
self,
hidden_size,
num_attention_heads,
mlp_ratio=4.0,
qkv_bias=True,
dtype=None,
prefix: str = "",
) -> None:
super().__init__()
self.num_attention_heads = num_attention_heads
mlp_hidden_dim = int(hidden_size * mlp_ratio)
# Normalization and attention
self.norm1 = nn.LayerNorm(
hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype
)
self.self_attn_qkv = ReplicatedLinear(
hidden_size,
hidden_size * 3,
bias=qkv_bias,
params_dtype=dtype,
prefix=f"{prefix}.self_attn_qkv",
)
self.self_attn_proj = ReplicatedLinear(
hidden_size,
hidden_size,
bias=qkv_bias,
params_dtype=dtype,
prefix=f"{prefix}.self_attn_proj",
)
# MLP
self.norm2 = nn.LayerNorm(
hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype
)
self.mlp = MLP(
hidden_size,
mlp_hidden_dim,
bias=True,
act_type="silu",
dtype=dtype,
prefix=f"{prefix}.mlp",
)
# Modulation
self.adaLN_modulation = ModulateProjection(
hidden_size,
factor=2,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.adaLN_modulation",
)
# Scaled dot product attention
self.attn = LocalAttention(
num_heads=num_attention_heads,
head_size=hidden_size // num_attention_heads,
# TODO: remove hardcode; remove STA
supported_attention_backends=(
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
),
)
def forward(self, x, c):
# Get modulation parameters
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1)
# Self-attention
norm_x = self.norm1(x)
qkv, _ = self.self_attn_qkv(norm_x)
batch_size, seq_len = qkv.shape[0], qkv.shape[1]
qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
# Run scaled dot product attention
attn_output = self.attn(q, k, v) # [B, L, H, D]
attn_output = attn_output.reshape(batch_size, seq_len, -1) # [B, L, H*D]
# Project and apply residual connection with gating
attn_out, _ = self.self_attn_proj(attn_output)
x = x + attn_out * gate_msa.unsqueeze(1)
# MLP
mlp_out = self.mlp(self.norm2(x))
x = x + mlp_out * gate_mlp.unsqueeze(1)
return x
class FinalLayer(nn.Module):
"""
The final layer of DiT that projects features to pixel space.
"""
def __init__(
self, hidden_size, patch_size, out_channels, dtype=None, prefix: str = ""
) -> None:
super().__init__()
# Normalization
self.norm_final = nn.LayerNorm(
hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype
)
output_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels
self.linear = ReplicatedLinear(
hidden_size,
output_dim,
bias=True,
params_dtype=dtype,
prefix=f"{prefix}.linear",
)
# Modulation
self.adaLN_modulation = ModulateProjection(
hidden_size,
factor=2,
act_layer="silu",
dtype=dtype,
prefix=f"{prefix}.adaLN_modulation",
)
def forward(self, x, c):
        # Note: HF diffusers orders the chunks as (scale, shift) here, unlike the
        # (shift, scale) order used elsewhere in this file.
scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1)
x = self.norm_final(x) * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
x, _ = self.linear(x)
return x
EntryClass = HunyuanVideoTransformer3DModel
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import functools
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormContinuous
from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig
from sglang.multimodal_gen.runtime.layers.attention import LocalAttention
from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm, RMSNorm
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.triton_ops import (
apply_rotary_embedding,
fuse_scale_shift_kernel,
)
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__) # pylint: disable=invalid-name
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim):
super().__init__()
self.time_proj = Timesteps(
num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000
)
self.timestep_embedder = TimestepEmbedding(
in_channels=256, time_embed_dim=embedding_dim
)
def forward(self, timestep, hidden_states):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(
timesteps_proj.to(dtype=hidden_states.dtype)
) # (N, D)
conditioning = timesteps_emb
return conditioning
class QwenEmbedRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
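        # Precompute rotary frequencies for positive positions (0..4095) and mirrored
        # negative positions (-1..-4096); the negative half is used when scale_rope
        # centers the height/width axes around zero.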
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)
# self.rope = NDRotaryEmbedding(
# rope_dim_list=axes_dim,
# rope_theta=theta,
# use_real=False,
# repeat_interleave_real=False,
# dtype=torch.float32 if current_platform.is_mps() else torch.float64,
# )
        # Do not use register_buffer here: casting buffers can drop the imaginary
        # part of complex tensors.
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
            index: 1D tensor of token position indices, e.g. [0, 1, 2, 3]
"""
device = index.device
assert dim % 2 == 0
freqs = torch.outer(
index,
(
1.0
/ torch.pow(
theta,
torch.arange(0, dim, 2, device=device).to(torch.float32).div(dim),
)
).to(device=device),
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def forward(
self,
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
txt_seq_lens: List[int],
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
            video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
                A tuple of 3 integers (frame, height, width) describing the latent video shape, or a list of such tuples.
txt_seq_lens (`List[int]`):
A list of integers of length batch_size representing the length of each text prompt.
device: (`torch.device`):
The device on which to perform the RoPE computation.
"""
# When models are initialized under a "meta" device context (e.g. init_empty_weights),
# tensors created during __init__ become meta tensors. Calling .to(...) on a meta tensor
# raises "Cannot copy out of meta tensor". Rebuild the frequencies on the target device
# in that case; otherwise move them if just on a different device.
if getattr(self.pos_freqs, "device", torch.device("meta")).type == "meta":
pos_index = torch.arange(4096, device=device)
neg_index = torch.arange(4096, device=device).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
).to(device=device)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
).to(device=device)
elif self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0).to(device=device)
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0
) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = (
freqs_pos[0][idx : idx + frame]
.view(frame, 1, 1, -1)
.expand(frame, height, width, -1)
)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]],
dim=0,
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(
frame, height, width, -1
)
freqs_width = torch.cat(
[freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]],
dim=0,
)
freqs_width = freqs_width.view(1, 1, width, -1).expand(
frame, height, width, -1
)
else:
freqs_height = (
freqs_pos[1][:height]
.view(1, height, 1, -1)
.expand(frame, height, width, -1)
)
freqs_width = (
freqs_pos[2][:width]
.view(1, 1, width, -1)
.expand(frame, height, width, -1)
)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(
seq_lens, -1
)
return freqs.clone().contiguous()
class QwenImageCrossAttention(nn.Module):
def __init__(
self,
dim: int, # query_dim
num_heads: int,
head_dim: int,
window_size=(-1, -1),
        added_kv_proj_dim: Optional[int] = None,
out_bias: bool = True,
qk_norm=True, # rmsnorm
eps=1e-6,
pre_only=False,
context_pre_only: bool = False,
parallel_attention=False,
        out_dim: Optional[int] = None,
) -> None:
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
self.parallel_attention = parallel_attention
# layers
self.to_q = ReplicatedLinear(dim, dim)
self.to_k = ReplicatedLinear(dim, dim)
self.to_v = ReplicatedLinear(dim, dim)
        self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads
self.inner_kv_dim = self.inner_dim
if added_kv_proj_dim is not None:
self.add_k_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_kv_dim, bias=True
)
self.add_v_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_kv_dim, bias=True
)
if context_pre_only is not None:
self.add_q_proj = ReplicatedLinear(
added_kv_proj_dim, self.inner_dim, bias=True
)
if context_pre_only is not None and not context_pre_only:
self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)
else:
self.to_add_out = None
if not pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(
ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)
)
else:
self.to_out = None
self.norm_added_q = RMSNorm(head_dim, eps=eps)
self.norm_added_k = RMSNorm(head_dim, eps=eps)
# Scaled dot product attention
self.attn = LocalAttention(
num_heads=num_heads,
head_size=self.head_dim,
dropout_rate=0,
softmax_scale=None,
causal=False,
supported_attention_backends={
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
},
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
**cross_attention_kwargs,
):
seq_txt = encoder_hidden_states.shape[1]
# Compute QKV for image stream (sample projections)
img_query, _ = self.to_q(hidden_states)
img_key, _ = self.to_k(hidden_states)
img_value, _ = self.to_v(hidden_states)
# Compute QKV for text stream (context projections)
txt_query, _ = self.add_q_proj(encoder_hidden_states)
txt_key, _ = self.add_k_proj(encoder_hidden_states)
txt_value, _ = self.add_v_proj(encoder_hidden_states)
# Reshape for multi-head attention
img_query = img_query.unflatten(-1, (self.num_heads, -1))
img_key = img_key.unflatten(-1, (self.num_heads, -1))
img_value = img_value.unflatten(-1, (self.num_heads, -1))
txt_query = txt_query.unflatten(-1, (self.num_heads, -1))
txt_key = txt_key.unflatten(-1, (self.num_heads, -1))
txt_value = txt_value.unflatten(-1, (self.num_heads, -1))
# Apply QK normalization
if self.norm_q is not None:
img_query = self.norm_q(img_query)
if self.norm_k is not None:
img_key = self.norm_k(img_key)
if self.norm_added_q is not None:
txt_query = self.norm_added_q(txt_query)
if self.norm_added_k is not None:
txt_key = self.norm_added_k(txt_key)
# Apply RoPE
if image_rotary_emb is not None:
(img_cos, img_sin), (txt_cos, txt_sin) = image_rotary_emb
img_query = apply_rotary_embedding(
img_query, img_cos, img_sin, interleaved=True
)
img_key = apply_rotary_embedding(
img_key, img_cos, img_sin, interleaved=True
)
txt_query = apply_rotary_embedding(
txt_query, txt_cos, txt_sin, interleaved=True
)
txt_key = apply_rotary_embedding(
txt_key, txt_cos, txt_sin, interleaved=True
)
# Concatenate for joint attention
# Order: [text, image]
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
# Compute joint attention
joint_hidden_states = self.attn(
joint_query,
joint_key,
joint_value,
)
# Reshape back
joint_hidden_states = joint_hidden_states.flatten(2, 3)
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
# Split attention outputs back
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
# Apply output projections
img_attn_output, _ = self.to_out[0](img_attn_output)
        if len(self.to_out) > 1:
            img_attn_output = self.to_out[1](img_attn_output)  # dropout, if present
txt_attn_output, _ = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
# Image processing modules
self.img_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(
dim, 6 * dim, bias=True
), # For scale, shift, gate for norm1 and norm2
)
self.img_norm1 = LayerNorm(dim, elementwise_affine=False, eps=eps)
self.attn = QwenImageCrossAttention(
dim=dim,
num_heads=num_attention_heads,
added_kv_proj_dim=dim,
context_pre_only=False,
head_dim=attention_head_dim,
)
self.img_norm2 = LayerNorm(dim, eps=eps, elementwise_affine=False)
self.img_mlp = FeedForward(
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)
# Text processing modules
self.txt_mod = nn.Sequential(
nn.SiLU(),
nn.Linear(
dim, 6 * dim, bias=True
), # For scale, shift, gate for norm1 and norm2
)
self.txt_norm1 = LayerNorm(dim, elementwise_affine=False, eps=eps)
        # Text does not need a separate attention module; it is handled by the joint attention in self.attn
self.txt_norm2 = LayerNorm(dim, elementwise_affine=False, eps=eps)
self.txt_mlp = FeedForward(
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)
def _modulate(self, x, mod_params):
"""Apply modulation to input tensor"""
shift, scale, gate = mod_params.chunk(3, dim=-1)
return fuse_scale_shift_kernel(x, scale, shift), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Get modulation parameters for both streams
img_mod_params = self.img_mod(temb) # [B, 6*dim]
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
# Split modulation parameters for norm1 and norm2
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
# Process text stream - norm1 + modulation
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
        # The joint cross-attention module implements the DoubleStreamLayerMegatron logic:
        # 1. Computes QKV for both streams
        # 2. Applies QK normalization and RoPE
        # 3. Concatenates and runs joint attention
        # 4. Splits results back to separate streams
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=img_modulated, # Image stream (will be processed as "sample")
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
        # QwenImageCrossAttention returns (img_output, txt_output) when encoder_hidden_states is provided
img_attn_output, txt_attn_output = attn_output
# Apply attention gates and add residual (like in Megatron)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_mlp_output = self.txt_mlp(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
# Clip to prevent overflow for fp16
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class QwenImageTransformer2DModel(CachableDiT):
"""
    The Transformer model used by Qwen-Image.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]
def __init__(
self,
config: QwenImageDitConfig,
hf_config: dict[str, Any],
):
super().__init__(config=config, hf_config=hf_config)
patch_size = config.arch_config.patch_size
in_channels = config.arch_config.in_channels
out_channels = config.arch_config.out_channels
num_layers = config.arch_config.num_layers
attention_head_dim = config.arch_config.attention_head_dim
num_attention_heads = config.arch_config.num_attention_heads
joint_attention_dim = config.arch_config.joint_attention_dim
axes_dims_rope = config.arch_config.axes_dims_rope
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.rotary_emb = QwenEmbedRope(
theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True
)
self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
self.img_in = nn.Linear(in_channels, self.inner_dim)
self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
QwenImageTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
)
self.proj_out = nn.Linear(
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
)
def forward(
self,
hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_hidden_states_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        freqs_cis: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        guidance: Optional[torch.Tensor] = None,  # TODO: this should probably be removed
attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
        The [`QwenImageTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
Mask of the input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if (
attention_kwargs is not None
and attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
if isinstance(encoder_hidden_states, list):
encoder_hidden_states = encoder_hidden_states[0]
hidden_states = self.img_in(hidden_states)
timestep = (timestep / 1000).to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
temb = self.time_text_embed(timestep, hidden_states)
image_rotary_emb = freqs_cis
for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(
controlnet_block_samples
)
interval_control = int(np.ceil(interval_control))
hidden_states = (
hidden_states
+ controlnet_block_samples[index_block // interval_control]
)
# Use only the image part (hidden_states) from the dual-stream blocks
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
return output
EntryClass = QwenImageTransformer2DModel
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Any
import torch
from einops import rearrange, repeat
from torch import nn
from sglang.multimodal_gen.configs.models.dits import StepVideoConfig
from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size
from sglang.multimodal_gen.runtime.layers.attention import LocalAttention, USPAttention
from sglang.multimodal_gen.runtime.layers.layernorm import LayerNormScaleShift
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
_apply_rotary_emb,
get_rotary_pos_embed,
)
from sglang.multimodal_gen.runtime.layers.visual_embedding import TimestepEmbedder
from sglang.multimodal_gen.runtime.models.dits.base import BaseDiT
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
class PatchEmbed2D(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
    The _assert call in the forward function was removed to stay compatible with multi-resolution images.
"""
def __init__(
self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
prefix: str = "",
):
super().__init__()
# Convert patch_size to 2-tuple
if isinstance(patch_size, list | tuple):
if len(patch_size) == 1:
patch_size = (patch_size[0], patch_size[0])
else:
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
dtype=dtype,
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class StepVideoRMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x) -> torch.Tensor:
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
class SelfAttention(nn.Module):
def __init__(
self,
hidden_dim,
head_dim,
rope_split: tuple[int, int, int] = (64, 32, 32),
bias: bool = False,
with_rope: bool = True,
with_qk_norm: bool = True,
attn_type: str = "torch",
supported_attention_backends=(
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
),
):
super().__init__()
self.head_dim = head_dim
self.hidden_dim = hidden_dim
self.rope_split = list(rope_split)
self.n_heads = hidden_dim // head_dim
self.wqkv = ReplicatedLinear(hidden_dim, hidden_dim * 3, bias=bias)
self.wo = ReplicatedLinear(hidden_dim, hidden_dim, bias=bias)
self.with_rope = with_rope
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True)
self.k_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True)
# self.core_attention = self.attn_processor(attn_type=attn_type)
self.parallel = attn_type == "parallel"
self.attn = USPAttention(
num_heads=self.n_heads,
head_size=head_dim,
causal=False,
supported_attention_backends=supported_attention_backends,
)
def _apply_rope(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
"""
x: [B, S, H, D]
cos: [S, D/2] where D = head_dim = sum(self.rope_split)
sin: [S, D/2]
returns x with rotary applied exactly as v0 did
"""
B, S, H, D = x.shape
# 1) split cos/sin per chunk
half_splits = [c // 2 for c in self.rope_split] # [32,16,16] for [64,32,32]
cos_splits = cos.split(half_splits, dim=1)
sin_splits = sin.split(half_splits, dim=1)
outs = []
idx = 0
for chunk_size, cos_i, sin_i in zip(
self.rope_split, cos_splits, sin_splits, strict=True
):
# slice the corresponding channels
x_chunk = x[..., idx : idx + chunk_size] # [B,S,H,chunk_size]
idx += chunk_size
# flatten to [S, B*H, chunk_size]
x_flat = rearrange(x_chunk, "b s h d -> s (b h) d")
# apply rotary on *that* chunk
out_flat = _apply_rotary_emb(x_flat, cos_i, sin_i, is_neox_style=True)
# restore [B,S,H,chunk_size]
out = rearrange(out_flat, "s (b h) d -> b s h d", b=B, h=H)
outs.append(out)
# concatenate back to [B,S,H,D]
return torch.cat(outs, dim=-1)
def forward(
self,
x,
cu_seqlens=None,
max_seqlen=None,
rope_positions=None,
cos_sin=None,
attn_mask=None,
mask_strategy=None,
):
B, S, _ = x.shape
xqkv, _ = self.wqkv(x)
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3 * self.head_dim)
q, k, v = torch.split(xqkv, [self.head_dim] * 3, dim=-1) # [B,S,H,D]
if self.with_qk_norm:
q = self.q_norm(q)
k = self.k_norm(k)
if self.with_rope:
if rope_positions is not None:
F, Ht, W = rope_positions
assert F * Ht * W == S, "rope_positions mismatches sequence length"
cos, sin = cos_sin
cos = cos.to(x.device, dtype=x.dtype)
sin = sin.to(x.device, dtype=x.dtype)
q = self._apply_rope(q, cos, sin)
k = self._apply_rope(k, cos, sin)
        output, _ = self.attn(q, k, v)  # [B, S, H, D]
output = rearrange(output, "b s h d -> b s (h d)")
output, _ = self.wo(output)
return output
class CrossAttention(nn.Module):
def __init__(
self,
hidden_dim,
head_dim,
bias=False,
with_qk_norm=True,
supported_attention_backends=(
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
),
) -> None:
super().__init__()
self.head_dim = head_dim
self.n_heads = hidden_dim // head_dim
self.wq = ReplicatedLinear(hidden_dim, hidden_dim, bias=bias)
self.wkv = ReplicatedLinear(hidden_dim, hidden_dim * 2, bias=bias)
self.wo = ReplicatedLinear(hidden_dim, hidden_dim, bias=bias)
self.with_qk_norm = with_qk_norm
if self.with_qk_norm:
self.q_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True)
self.k_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True)
self.attn = LocalAttention(
num_heads=self.n_heads,
head_size=head_dim,
causal=False,
supported_attention_backends=supported_attention_backends,
)
def forward(
self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, attn_mask=None
) -> torch.Tensor:
xq, _ = self.wq(x)
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
xkv, _ = self.wkv(encoder_hidden_states)
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2 * self.head_dim)
        xk, xv = torch.split(xkv, [self.head_dim] * 2, dim=-1)  # [B, S, n_heads, head_dim]
if self.with_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
output = self.attn(xq, xk, xv)
output = rearrange(output, "b s h d -> b s (h d)")
output, _ = self.wo(output)
return output
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, time_step_rescale=1000):
super().__init__()
self.emb = TimestepEmbedder(embedding_dim)
self.silu = nn.SiLU()
self.linear = ReplicatedLinear(embedding_dim, 6 * embedding_dim, bias=True)
        self.time_step_rescale = time_step_rescale  # timesteps are usually in [0, 1]; rescale to [0, 1000] for stability
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: dict[str, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
embedded_timestep = self.emb(timestep * self.time_step_rescale)
out, _ = self.linear(self.silu(embedded_timestep))
return out, embedded_timestep
class StepVideoTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
attention_head_dim: int,
norm_eps: float = 1e-5,
ff_inner_dim: int | None = None,
ff_bias: bool = False,
attention_type: str = "torch",
):
super().__init__()
self.dim = dim
self.norm1 = LayerNormScaleShift(
dim, norm_type="layer", elementwise_affine=True, eps=norm_eps
)
self.attn1 = SelfAttention(
dim,
attention_head_dim,
bias=False,
with_rope=True,
with_qk_norm=True,
)
self.norm2 = LayerNormScaleShift(
dim, norm_type="layer", elementwise_affine=True, eps=norm_eps
)
self.attn2 = CrossAttention(
dim, attention_head_dim, bias=False, with_qk_norm=True
)
self.ff = MLP(
input_dim=dim,
mlp_hidden_dim=dim * 4 if ff_inner_dim is None else ff_inner_dim,
act_type="gelu_pytorch_tanh",
bias=ff_bias,
)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
@torch.no_grad()
def forward(
self,
q: torch.Tensor,
kv: torch.Tensor,
t_expand: torch.LongTensor,
attn_mask=None,
rope_positions: list | None = None,
cos_sin=None,
mask_strategy=None,
) -> torch.Tensor:
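        # adaLN-single: add the learned per-block scale/shift table to the six
        # modulation vectors projected from the timestep embedding, then split them
        # into (shift, scale, gate) for the attention and MLP branches.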
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
torch.clone(chunk)
for chunk in (
self.scale_shift_table[None] + t_expand.reshape(-1, 6, self.dim)
).chunk(6, dim=1)
)
scale_shift_q = self.norm1(
q, scale=scale_msa.squeeze(1), shift=shift_msa.squeeze(1)
)
attn_q = self.attn1(
scale_shift_q,
rope_positions=rope_positions,
cos_sin=cos_sin,
mask_strategy=mask_strategy,
)
q = attn_q * gate_msa + q
attn_q = self.attn2(q, kv, attn_mask)
q = attn_q + q
scale_shift_q = self.norm2(
q, scale=scale_mlp.squeeze(1), shift=shift_mlp.squeeze(1)
)
ff_output = self.ff(scale_shift_q)
q = ff_output * gate_mlp + q
return q
class StepVideoModel(BaseDiT):
# (Optional) Keep the same attribute for compatibility with splitting, etc.
_fsdp_shard_conditions = [
lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit(),
# lambda n, m: "pos_embed" in n # If needed for the patch embedding.
]
param_names_mapping = StepVideoConfig().param_names_mapping
reverse_param_names_mapping = StepVideoConfig().reverse_param_names_mapping
lora_param_names_mapping = StepVideoConfig().lora_param_names_mapping
_supported_attention_backends = StepVideoConfig()._supported_attention_backends
def __init__(self, config: StepVideoConfig, hf_config: dict[str, Any]) -> None:
super().__init__(config=config, hf_config=hf_config)
self.num_attention_heads = config.num_attention_heads
self.attention_head_dim = config.attention_head_dim
self.in_channels = config.in_channels
self.out_channels = config.out_channels
self.num_layers = config.num_layers
self.dropout = config.dropout
self.patch_size = config.patch_size
self.norm_type = config.norm_type
self.norm_elementwise_affine = config.norm_elementwise_affine
self.norm_eps = config.norm_eps
self.use_additional_conditions = config.use_additional_conditions
self.caption_channels = config.caption_channels
self.attention_type = config.attention_type
self.num_channels_latents = config.num_channels_latents
# Compute inner dimension.
self.hidden_size = config.hidden_size
# Image/video patch embedding.
self.pos_embed = PatchEmbed2D(
patch_size=self.patch_size,
in_chans=self.in_channels,
embed_dim=self.hidden_size,
)
self._rope_cache: dict[tuple, tuple[torch.Tensor, torch.Tensor]] = {}
# Transformer blocks.
self.transformer_blocks = nn.ModuleList(
[
StepVideoTransformerBlock(
dim=self.hidden_size,
attention_head_dim=self.attention_head_dim,
attention_type=self.attention_type,
)
for _ in range(self.num_layers)
]
)
# Output blocks.
self.norm_out = LayerNormScaleShift(
self.hidden_size,
norm_type="layer",
eps=self.norm_eps,
elementwise_affine=self.norm_elementwise_affine,
)
self.scale_shift_table = nn.Parameter(
torch.randn(2, self.hidden_size) / (self.hidden_size**0.5)
)
self.proj_out = ReplicatedLinear(
self.hidden_size, self.patch_size * self.patch_size * self.out_channels
)
# Time modulation via adaptive layer norm.
self.adaln_single = AdaLayerNormSingle(self.hidden_size)
# Set up caption conditioning.
if isinstance(self.caption_channels, int):
caption_channel = self.caption_channels
else:
caption_channel, clip_channel = self.caption_channels
self.clip_projection = ReplicatedLinear(clip_channel, self.hidden_size)
self.caption_norm = nn.LayerNorm(
caption_channel,
eps=self.norm_eps,
elementwise_affine=self.norm_elementwise_affine,
)
self.caption_projection = MLP(
input_dim=caption_channel,
mlp_hidden_dim=self.hidden_size,
act_type="gelu_pytorch_tanh",
)
# Flag to indicate if using parallel attention.
self.parallel = self.attention_type == "parallel"
self.__post_init__()
def patchfy(self, hidden_states) -> torch.Tensor:
hidden_states = rearrange(hidden_states, "b f c h w -> (b f) c h w")
hidden_states = self.pos_embed(hidden_states)
return hidden_states
def prepare_attn_mask(
self, encoder_attention_mask, encoder_hidden_states, q_seqlen
) -> tuple[torch.Tensor, torch.Tensor]:
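        # Build a boolean [B, q_seqlen, max_kv_len] mask that keeps only the valid
        # (non-padded) caption tokens per sample, and trim the caption embeddings to
        # the longest valid length in the batch.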
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
mask = torch.zeros(
[len(kv_seqlens), q_seqlen, max(kv_seqlens)],
dtype=torch.bool,
device=encoder_attention_mask.device,
)
encoder_hidden_states = encoder_hidden_states[:, : max(kv_seqlens)]
for i, kv_len in enumerate(kv_seqlens):
mask[i, :, :kv_len] = 1
return encoder_hidden_states, mask
def block_forward(
self,
hidden_states,
encoder_hidden_states=None,
t_expand=None,
rope_positions=None,
cos_sin=None,
attn_mask=None,
parallel=True,
mask_strategy=None,
) -> torch.Tensor:
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states,
encoder_hidden_states,
t_expand=t_expand,
attn_mask=attn_mask,
rope_positions=rope_positions,
cos_sin=cos_sin,
mask_strategy=mask_strategy[i],
)
return hidden_states
def _get_rope(
self,
rope_positions: tuple[int, int, int],
dtype: torch.dtype,
device: torch.device,
):
F, Ht, W = rope_positions
key = (F, Ht, W, dtype)
if key not in self._rope_cache:
cos, sin = get_rotary_pos_embed(
rope_sizes=(F * get_sp_world_size(), Ht, W),
hidden_size=self.hidden_size,
heads_num=self.hidden_size // self.attention_head_dim,
rope_dim_list=(64, 32, 32), # same split you used
rope_theta=1.0e4,
dtype=torch.float32, # build once in fp32
)
# move & cast once
self._rope_cache[key] = (
cos.to(device, dtype=dtype),
sin.to(device, dtype=dtype),
)
return self._rope_cache[key]
@torch.inference_mode()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
t_expand: torch.LongTensor | None = None,
encoder_hidden_states_2: torch.Tensor | None = None,
added_cond_kwargs: dict[str, torch.Tensor] | None = None,
encoder_attention_mask: torch.Tensor | None = None,
fps: torch.Tensor | None = None,
return_dict: bool = True,
mask_strategy=None,
guidance=None,
):
        assert (
            hidden_states.ndim == 5
        ), "hidden_states should have shape (bsz, c, f, h, w)"
frame = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> b f c h w", f=frame)
if mask_strategy is None:
            mask_strategy = [None] * len(self.transformer_blocks)
bsz, frame, _, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
hidden_states = self.patchfy(hidden_states)
len_frame = hidden_states.shape[1]
t_expand, embedded_timestep = self.adaln_single(t_expand)
encoder_hidden_states = self.caption_projection(
self.caption_norm(encoder_hidden_states)
)
if encoder_hidden_states_2 is not None and hasattr(self, "clip_projection"):
clip_embedding, _ = self.clip_projection(encoder_hidden_states_2)
encoder_hidden_states = torch.cat(
[clip_embedding, encoder_hidden_states], dim=1
)
hidden_states = rearrange(
hidden_states, "(b f) l d-> b (f l) d", b=bsz, f=frame, l=len_frame
).contiguous()
encoder_hidden_states, attn_mask = self.prepare_attn_mask(
encoder_attention_mask, encoder_hidden_states, q_seqlen=frame * len_frame
)
cos_sin = self._get_rope(
(frame, height, width), hidden_states.dtype, hidden_states.device
)
hidden_states = self.block_forward(
hidden_states,
encoder_hidden_states,
t_expand=t_expand,
rope_positions=[frame, height, width],
cos_sin=cos_sin,
attn_mask=attn_mask,
parallel=self.parallel,
mask_strategy=mask_strategy,
)
hidden_states = rearrange(
hidden_states, "b (f l) d -> (b f) l d", b=bsz, f=frame, l=len_frame
)
embedded_timestep = repeat(
embedded_timestep, "b d -> (b f) d", f=frame
).contiguous()
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None]
).chunk(2, dim=1)
hidden_states = self.norm_out(
hidden_states, shift=shift.squeeze(1), scale=scale.squeeze(1)
)
# Modulation
hidden_states, _ = self.proj_out(hidden_states)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
height,
width,
self.patch_size,
self.patch_size,
self.out_channels,
)
)
hidden_states = rearrange(hidden_states, "n h w p q c -> n c h p w q")
output = hidden_states.reshape(
shape=(
-1,
self.out_channels,
height * self.patch_size,
width * self.patch_size,
)
)
output = rearrange(output, "(b f) c h w -> b c f h w", f=frame)
return output
EntryClass = StepVideoModel
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from sglang.multimodal_gen.configs.models.dits import WanVideoConfig
from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams
from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size
from sglang.multimodal_gen.runtime.layers.attention import (
LocalAttention,
UlyssesAttention_VSA,
USPAttention,
)
from sglang.multimodal_gen.runtime.layers.layernorm import (
FP32LayerNorm,
LayerNormScaleShift,
RMSNorm,
ScaleResidual,
ScaleResidualLayerNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
from sglang.multimodal_gen.runtime.layers.mlp import MLP
from sglang.multimodal_gen.runtime.layers.rotary_embedding import (
NDRotaryEmbedding,
_apply_rotary_emb,
)
from sglang.multimodal_gen.runtime.layers.visual_embedding import (
ModulateProjection,
PatchEmbed,
TimestepEmbedder,
)
from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context
from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT
from sglang.multimodal_gen.runtime.platforms import (
AttentionBackendEnum,
current_platform,
)
from sglang.multimodal_gen.runtime.server_args import get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = MLP(in_features, in_features, out_features, act_type="gelu")
self.norm2 = FP32LayerNorm(out_features)
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
dtype = encoder_hidden_states_image.dtype
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states).to(dtype)
return hidden_states
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
text_embed_dim: int,
image_embed_dim: int | None = None,
):
super().__init__()
self.time_embedder = TimestepEmbedder(
dim, frequency_embedding_size=time_freq_dim, act_layer="silu"
)
self.time_modulation = ModulateProjection(dim, factor=6, act_layer="silu")
self.text_embedder = MLP(
text_embed_dim, dim, dim, bias=True, act_type="gelu_pytorch_tanh"
)
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: torch.Tensor | None = None,
timestep_seq_len: int | None = None,
):
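        # Returns the raw timestep embedding (temb), its 6-way modulation projection
        # (timestep_proj), the projected text embeddings, and optionally the
        # projected CLIP image embeddings.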
temb = self.time_embedder(timestep, timestep_seq_len)
timestep_proj = self.time_modulation(temb)
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
assert self.image_embedder is not None
encoder_hidden_states_image = self.image_embedder(
encoder_hidden_states_image
)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class WanSelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
parallel_attention=False,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
) -> None:
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
self.parallel_attention = parallel_attention
# layers
self.to_q = ReplicatedLinear(dim, dim)
self.to_k = ReplicatedLinear(dim, dim)
self.to_v = ReplicatedLinear(dim, dim)
self.to_out = ReplicatedLinear(dim, dim)
self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
# Scaled dot product attention
self.attn = LocalAttention(
num_heads=num_heads,
head_size=self.head_dim,
dropout_rate=0,
softmax_scale=None,
causal=False,
supported_attention_backends=supported_attention_backends,
)
def forward(self, x: torch.Tensor, context: torch.Tensor, context_lens: int):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
pass
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, context_lens, crossattn_cache=None):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d)
if crossattn_cache is not None:
if not crossattn_cache["is_init"]:
crossattn_cache["is_init"] = True
k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d)
v = self.to_v(context)[0].view(b, -1, n, d)
crossattn_cache["k"] = k
crossattn_cache["v"] = v
else:
k = crossattn_cache["k"]
v = crossattn_cache["v"]
else:
k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d)
v = self.to_v(context)[0].view(b, -1, n, d)
# compute attention
x = self.attn(q, k, v)
# output
x = x.flatten(2)
x, _ = self.to_out(x)
return x
class WanI2VCrossAttention(WanSelfAttention):
def __init__(
self,
dim: int,
num_heads: int,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
) -> None:
# VSA should not be in supported_attention_backends
super().__init__(
dim,
num_heads,
window_size,
qk_norm,
eps,
supported_attention_backends=supported_attention_backends,
)
self.add_k_proj = ReplicatedLinear(dim, dim)
self.add_v_proj = ReplicatedLinear(dim, dim)
self.norm_added_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_added_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
def forward(self, x, context, context_lens):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
context_img = context[:, :257]
context = context[:, 257:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d)
k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d)
v = self.to_v(context)[0].view(b, -1, n, d)
k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view(b, -1, n, d)
v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d)
img_x = self.attn(q, k_img, v_img)
# compute attention
x = self.attn(q, k, v)
# output
x = x.flatten(2)
img_x = img_x.flatten(2)
x = x + img_x
x, _ = self.to_out(x)
return x
class WanTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: int | None = None,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
prefix: str = "",
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.to_q = ReplicatedLinear(dim, dim, bias=True)
self.to_k = ReplicatedLinear(dim, dim, bias=True)
self.to_v = ReplicatedLinear(dim, dim, bias=True)
self.to_out = ReplicatedLinear(dim, dim, bias=True)
self.attn1 = USPAttention(
num_heads=num_heads,
head_size=dim // num_heads,
causal=False,
supported_attention_backends=supported_attention_backends,
prefix=f"{prefix}.attn1",
)
self.hidden_dim = dim
self.num_attention_heads = num_heads
dim_head = dim // num_heads
if qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
else:
logger.error("QK Norm type not supported")
raise Exception
assert cross_attn_norm is True
self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=True,
dtype=torch.float32,
compute_dtype=torch.float32,
)
# 2. Cross-attention
if added_kv_proj_dim is not None:
# I2V
self.attn2 = WanI2VCrossAttention(
dim,
num_heads,
qk_norm=qk_norm,
eps=eps,
supported_attention_backends=supported_attention_backends,
)
else:
# T2V
self.attn2 = WanT2VCrossAttention(
dim,
num_heads,
qk_norm=qk_norm,
eps=eps,
supported_attention_backends=supported_attention_backends,
)
self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)
# 3. Feed-forward
self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh")
self.mlp_residual = ScaleResidual()
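        # AdaLN-style modulation table: 6 chunks of (shift, scale, gate),
        # three for the self-attention branch and three for the cross-attn/FFN branch.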
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
if hidden_states.dim() == 4:
hidden_states = hidden_states.squeeze(1)
bs, seq_length, _ = hidden_states.shape
orig_dtype = hidden_states.dtype
if temb.dim() == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.unsqueeze(0) + temb.float()
).chunk(6, dim=2)
# batch_size, seq_len, 1, inner_dim
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
gate_msa = gate_msa.squeeze(2)
c_shift_msa = c_shift_msa.squeeze(2)
c_scale_msa = c_scale_msa.squeeze(2)
c_gate_msa = c_gate_msa.squeeze(2)
else:
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
e = self.scale_shift_table + temb.float()
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
e.chunk(6, dim=1)
)
assert shift_msa.dtype == torch.float32
# 1. Self-attention
norm_hidden_states = (
self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa
).to(orig_dtype)
query, _ = self.to_q(norm_hidden_states)
key, _ = self.to_k(norm_hidden_states)
value, _ = self.to_v(norm_hidden_states)
if self.norm_q is not None:
query = self.norm_q(query)
if self.norm_k is not None:
key = self.norm_k(key)
query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
# Apply rotary embeddings
cos, sin = freqs_cis
query, key = _apply_rotary_emb(
query, cos, sin, is_neox_style=False
), _apply_rotary_emb(key, cos, sin, is_neox_style=False)
attn_output, _ = self.attn1(query, key, value)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
attn_output = attn_output.squeeze(1)
null_shift = null_scale = torch.zeros(
(1,), device=hidden_states.device, dtype=hidden_states.dtype
)
norm_hidden_states, hidden_states = self.self_attn_residual_norm(
hidden_states, attn_output, gate_msa, null_shift, null_scale
)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype
), hidden_states.to(orig_dtype)
# 2. Cross-attention
attn_output = self.attn2(
norm_hidden_states, context=encoder_hidden_states, context_lens=None
)
norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
hidden_states, attn_output, 1, c_shift_msa, c_scale_msa
)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype
), hidden_states.to(orig_dtype)
# 3. Feed-forward
ff_output = self.ffn(norm_hidden_states)
hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa)
hidden_states = hidden_states.to(orig_dtype)
return hidden_states
class WanTransformerBlock_VSA(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: int | None = None,
supported_attention_backends: set[AttentionBackendEnum] | None = None,
prefix: str = "",
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.to_q = ReplicatedLinear(dim, dim, bias=True)
self.to_k = ReplicatedLinear(dim, dim, bias=True)
self.to_v = ReplicatedLinear(dim, dim, bias=True)
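        # Extra projection producing the gate passed to UlyssesAttention_VSA's
        # compression branch (video sparse attention).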
self.to_gate_compress = ReplicatedLinear(dim, dim, bias=True)
self.to_out = ReplicatedLinear(dim, dim, bias=True)
self.attn1 = UlyssesAttention_VSA(
num_heads=num_heads,
head_size=dim // num_heads,
causal=False,
supported_attention_backends=supported_attention_backends,
prefix=f"{prefix}.attn1",
)
self.hidden_dim = dim
self.num_attention_heads = num_heads
dim_head = dim // num_heads
if qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim, eps=eps)
self.norm_k = RMSNorm(dim, eps=eps)
else:
logger.error("QK Norm type not supported")
raise Exception
assert cross_attn_norm is True
self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=True,
dtype=torch.float32,
compute_dtype=torch.float32,
)
if AttentionBackendEnum.VIDEO_SPARSE_ATTN in supported_attention_backends:
supported_attention_backends.remove(AttentionBackendEnum.VIDEO_SPARSE_ATTN)
# 2. Cross-attention
if added_kv_proj_dim is not None:
# I2V
self.attn2 = WanI2VCrossAttention(
dim,
num_heads,
qk_norm=qk_norm,
eps=eps,
supported_attention_backends=supported_attention_backends,
)
else:
# T2V
self.attn2 = WanT2VCrossAttention(
dim,
num_heads,
qk_norm=qk_norm,
eps=eps,
supported_attention_backends=supported_attention_backends,
)
self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
dim,
norm_type="layer",
eps=eps,
elementwise_affine=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)
# 3. Feed-forward
self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh")
self.mlp_residual = ScaleResidual()
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
if hidden_states.dim() == 4:
hidden_states = hidden_states.squeeze(1)
bs, seq_length, _ = hidden_states.shape
orig_dtype = hidden_states.dtype
# assert orig_dtype != torch.float32
e = self.scale_shift_table + temb.float()
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
6, dim=1
)
assert shift_msa.dtype == torch.float32
# 1. Self-attention
norm_hidden_states = (
self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa
).to(orig_dtype)
query, _ = self.to_q(norm_hidden_states)
key, _ = self.to_k(norm_hidden_states)
value, _ = self.to_v(norm_hidden_states)
gate_compress, _ = self.to_gate_compress(norm_hidden_states)
if self.norm_q is not None:
query = self.norm_q(query)
if self.norm_k is not None:
key = self.norm_k(key)
query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
gate_compress = gate_compress.squeeze(1).unflatten(
2, (self.num_attention_heads, -1)
)
# Apply rotary embeddings
cos, sin = freqs_cis
query, key = _apply_rotary_emb(
query, cos, sin, is_neox_style=False
), _apply_rotary_emb(key, cos, sin, is_neox_style=False)
attn_output, _ = self.attn1(query, key, value, gate_compress=gate_compress)
attn_output = attn_output.flatten(2)
attn_output, _ = self.to_out(attn_output)
attn_output = attn_output.squeeze(1)
        null_shift = null_scale = torch.zeros(
            (1,), device=hidden_states.device, dtype=hidden_states.dtype
        )
norm_hidden_states, hidden_states = self.self_attn_residual_norm(
hidden_states, attn_output, gate_msa, null_shift, null_scale
)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype
), hidden_states.to(orig_dtype)
# 2. Cross-attention
attn_output = self.attn2(
norm_hidden_states, context=encoder_hidden_states, context_lens=None
)
norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
hidden_states, attn_output, 1, c_shift_msa, c_scale_msa
)
norm_hidden_states, hidden_states = norm_hidden_states.to(
orig_dtype
), hidden_states.to(orig_dtype)
# 3. Feed-forward
ff_output = self.ffn(norm_hidden_states)
hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa)
hidden_states = hidden_states.to(orig_dtype)
return hidden_states
class WanTransformer3DModel(CachableDiT):
_fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions
_compile_conditions = WanVideoConfig()._compile_conditions
_supported_attention_backends = WanVideoConfig()._supported_attention_backends
param_names_mapping = WanVideoConfig().param_names_mapping
reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping
lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping
def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None:
super().__init__(config=config, hf_config=hf_config)
inner_dim = config.num_attention_heads * config.attention_head_dim
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.in_channels = config.in_channels
self.out_channels = config.out_channels
self.num_channels_latents = config.num_channels_latents
self.patch_size = config.patch_size
self.text_len = config.text_len
# 1. Patch & position embedding
self.patch_embedding = PatchEmbed(
in_chans=config.in_channels,
embed_dim=inner_dim,
patch_size=config.patch_size,
flatten=False,
)
# 2. Condition embeddings
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=config.freq_dim,
text_embed_dim=config.text_dim,
image_embed_dim=config.image_dim,
)
# 3. Transformer blocks
attn_backend = get_global_server_args().attention_backend
transformer_block = (
WanTransformerBlock_VSA
if (attn_backend and attn_backend.lower() == "video_sparse_attn")
else WanTransformerBlock
)
self.blocks = nn.ModuleList(
[
transformer_block(
inner_dim,
config.ffn_dim,
config.num_attention_heads,
config.qk_norm,
config.cross_attn_norm,
config.eps,
config.added_kv_proj_dim,
self._supported_attention_backends
| {AttentionBackendEnum.VIDEO_SPARSE_ATTN},
prefix=f"{config.prefix}.blocks.{i}",
)
for i in range(config.num_layers)
]
)
# 4. Output norm & projection
self.norm_out = LayerNormScaleShift(
inner_dim,
norm_type="layer",
eps=config.eps,
elementwise_affine=False,
dtype=torch.float32,
compute_dtype=torch.float32,
)
self.proj_out = nn.Linear(
inner_dim, config.out_channels * math.prod(config.patch_size)
)
self.scale_shift_table = nn.Parameter(
torch.randn(1, 2, inner_dim) / inner_dim**0.5
)
# For type checking
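        # TeaCache state: the "even" slots track the conditional pass and the "odd"
        # slots the unconditional (CFG) pass; each caches the previous modulation
        # input and block residual so whole forward passes can be skipped.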
self.previous_e0_even = None
self.previous_e0_odd = None
self.previous_residual_even = None
self.previous_residual_odd = None
self.is_even = True
self.should_calc_even = True
self.should_calc_odd = True
self.accumulated_rel_l1_distance_even = 0
self.accumulated_rel_l1_distance_odd = 0
self.cnt = 0
self.__post_init__()
# misc
self.sp_size = get_sp_world_size()
# Get rotary embeddings
d = self.hidden_size // self.num_attention_heads
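        # Split the per-head dim across (temporal, height, width) axes for 3D RoPE.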
self.rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
self.rope = NDRotaryEmbedding(
rope_dim_list=self.rope_dim_list,
rope_theta=10000,
dtype=torch.float32 if current_platform.is_mps() else torch.float64,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | list[torch.Tensor],
timestep: torch.LongTensor,
encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None,
guidance=None,
**kwargs,
) -> torch.Tensor:
forward_batch = get_forward_context().forward_batch
enable_teacache = forward_batch is not None and forward_batch.enable_teacache
orig_dtype = hidden_states.dtype
if not isinstance(encoder_hidden_states, torch.Tensor):
encoder_hidden_states = encoder_hidden_states[0]
if (
isinstance(encoder_hidden_states_image, list)
and len(encoder_hidden_states_image) > 0
):
encoder_hidden_states_image = encoder_hidden_states_image[0]
else:
encoder_hidden_states_image = None
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
freqs_cos, freqs_sin = self.rope.forward_from_grid(
(
post_patch_num_frames * self.sp_size,
post_patch_height,
post_patch_width,
),
shard_dim=0,
start_frame=0,
device=hidden_states.device,
)
assert freqs_cos.dtype == torch.float32
assert freqs_cos.device == hidden_states.device
freqs_cis = (
(freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None
)
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
        # timestep shape: (batch_size,) or (batch_size, seq_len) for Wan 2.2 TI2V
if timestep.dim() == 2:
ts_seq_len = timestep.shape[1]
timestep = timestep.flatten() # batch_size * seq_len
else:
ts_seq_len = None
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = (
self.condition_embedder(
timestep,
encoder_hidden_states,
encoder_hidden_states_image,
timestep_seq_len=ts_seq_len,
)
)
if ts_seq_len is not None:
# batch_size, seq_len, 6, inner_dim
timestep_proj = timestep_proj.unflatten(2, (6, -1))
else:
# batch_size, 6, inner_dim
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
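            # Prepend image embeddings to the text embeddings; WanI2VCrossAttention
            # later slices them back off as the leading tokens of the context.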
encoder_hidden_states = torch.concat(
[encoder_hidden_states_image, encoder_hidden_states], dim=1
)
encoder_hidden_states = (
encoder_hidden_states.to(orig_dtype)
if current_platform.is_mps()
else encoder_hidden_states
) # cast to orig_dtype for MPS
assert encoder_hidden_states.dtype == orig_dtype
# 4. Transformer blocks
# if caching is enabled, we might be able to skip the forward pass
should_skip_forward = self.should_skip_forward_for_cached_states(
timestep_proj=timestep_proj, temb=temb
)
if should_skip_forward:
hidden_states = self.retrieve_cached_states(hidden_states)
else:
# if teacache is enabled, we need to cache the original hidden states
if enable_teacache:
original_hidden_states = hidden_states.clone()
for block in self.blocks:
hidden_states = block(
hidden_states, encoder_hidden_states, timestep_proj, freqs_cis
)
# if teacache is enabled, we need to cache the original hidden states
if enable_teacache:
self.maybe_cache_states(hidden_states, original_hidden_states)
# 5. Output norm, projection & unpatchify
if temb.dim() == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
shift, scale = (
self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)
).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states, shift, scale)
hidden_states = self.proj_out(hidden_states)
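        # Unpatchify: reshape to (B, F_p, H_p, W_p, p_t, p_h, p_w, C), then
        # permute/flatten back to (B, C, F, H, W).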
hidden_states = hidden_states.reshape(
batch_size,
post_patch_num_frames,
post_patch_height,
post_patch_width,
p_t,
p_h,
p_w,
-1,
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
return output
def maybe_cache_states(
self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor
) -> None:
if self.is_even:
self.previous_residual_even = (
hidden_states.squeeze(0) - original_hidden_states
)
else:
self.previous_residual_odd = (
hidden_states.squeeze(0) - original_hidden_states
)
def should_skip_forward_for_cached_states(self, **kwargs) -> bool:
forward_context = get_forward_context()
forward_batch = forward_context.forward_batch
if forward_batch is None or not forward_batch.enable_teacache:
return False
teacache_params = forward_batch.teacache_params
assert teacache_params is not None, "teacache_params is not initialized"
assert isinstance(
teacache_params, WanTeaCacheParams
), "teacache_params is not a WanTeaCacheParams"
current_timestep = forward_context.current_timestep
num_inference_steps = forward_batch.num_inference_steps
# initialize the coefficients, cutoff_steps, and ret_steps
coefficients = teacache_params.coefficients
use_ret_steps = teacache_params.use_ret_steps
cutoff_steps = teacache_params.get_cutoff_steps(num_inference_steps)
ret_steps = teacache_params.ret_steps
teacache_thresh = teacache_params.teacache_thresh
if current_timestep == 0:
self.cnt = 0
timestep_proj = kwargs["timestep_proj"]
temb = kwargs["temb"]
modulated_inp = timestep_proj if use_ret_steps else temb
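        # TeaCache: accumulate a rescaled relative-L1 distance between the current
        # and previous modulation inputs; if it stays below the threshold, the
        # transformer blocks are skipped and the cached residual is reused.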
if self.cnt % 2 == 0: # even -> condition
self.is_even = True
if self.cnt < ret_steps or self.cnt >= cutoff_steps:
self.should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
else:
assert (
self.previous_e0_even is not None
), "previous_e0_even is not initialized"
assert (
self.accumulated_rel_l1_distance_even is not None
), "accumulated_rel_l1_distance_even is not initialized"
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(
(
(modulated_inp - self.previous_e0_even).abs().mean()
/ self.previous_e0_even.abs().mean()
)
.cpu()
.item()
)
if self.accumulated_rel_l1_distance_even < teacache_thresh:
self.should_calc_even = False
else:
self.should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = modulated_inp.clone()
        else:  # odd -> uncondition
self.is_even = False
if self.cnt < ret_steps or self.cnt >= cutoff_steps:
self.should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
else:
assert (
self.previous_e0_odd is not None
), "previous_e0_odd is not initialized"
assert (
self.accumulated_rel_l1_distance_odd is not None
), "accumulated_rel_l1_distance_odd is not initialized"
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(
(
(modulated_inp - self.previous_e0_odd).abs().mean()
/ self.previous_e0_odd.abs().mean()
)
.cpu()
.item()
)
if self.accumulated_rel_l1_distance_odd < teacache_thresh:
self.should_calc_odd = False
else:
self.should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = modulated_inp.clone()
self.cnt += 1
should_skip_forward = False
if self.is_even:
if not self.should_calc_even:
should_skip_forward = True
else:
if not self.should_calc_odd:
should_skip_forward = True
return should_skip_forward
def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.is_even:
return hidden_states + self.previous_residual_even
else:
return hidden_states + self.previous_residual_odd
EntryClass = WanTransformer3DModel
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from dataclasses import field
import torch
from torch import nn
from sglang.multimodal_gen.configs.models.encoders import (
BaseEncoderOutput,
ImageEncoderConfig,
TextEncoderConfig,
)
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
class TextEncoder(nn.Module, ABC):
_fsdp_shard_conditions: list = field(default_factory=lambda: [])
_stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list)
_supported_attention_backends: set[AttentionBackendEnum] = (
TextEncoderConfig()._supported_attention_backends
)
def __init__(self, config: TextEncoderConfig) -> None:
super().__init__()
self.config = config
self._fsdp_shard_conditions = config._fsdp_shard_conditions
self._stacked_params_mapping = config.arch_config.stacked_params_mapping
if not self.supported_attention_backends:
raise ValueError(
f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
)
@abstractmethod
def forward(
self,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
**kwargs,
) -> BaseEncoderOutput:
pass
@property
def supported_attention_backends(self) -> set[AttentionBackendEnum]:
return self._supported_attention_backends
class ImageEncoder(nn.Module, ABC):
_supported_attention_backends: set[AttentionBackendEnum] = (
ImageEncoderConfig()._supported_attention_backends
)
def __init__(self, config: ImageEncoderConfig) -> None:
super().__init__()
self.config = config
if not self.supported_attention_backends:
raise ValueError(
f"Subclass {self.__class__.__name__} must define _supported_attention_backends"
)
@abstractmethod
def forward(self, pixel_values: torch.Tensor, **kwargs) -> BaseEncoderOutput:
pass
@property
def supported_attention_backends(self) -> set[AttentionBackendEnum]:
return self._supported_attention_backends
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# type: ignore
import os
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
class HunyuanClip(nn.Module):
"""
Hunyuan clip code copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
hunyuan's clip used BertModel and BertTokenizer, so we copy it.
"""
def __init__(self, model_dir, max_length=77):
super().__init__()
self.max_length = max_length
self.tokenizer = BertTokenizer.from_pretrained(
os.path.join(model_dir, "tokenizer")
)
self.text_encoder = BertModel.from_pretrained(
os.path.join(model_dir, "clip_text_encoder")
)
@torch.no_grad
def forward(self, prompts, with_mask=True):
self.device = next(self.text_encoder.parameters()).device
text_inputs = self.tokenizer(
prompts,
padding="max_length",
max_length=self.max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
prompt_embeds = self.text_encoder(
text_inputs.input_ids.to(self.device),
attention_mask=(
text_inputs.attention_mask.to(self.device) if with_mask else None
),
)
return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/clip.py
# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from sglang.multimodal_gen.configs.models.encoders import (
BaseEncoderOutput,
CLIPTextConfig,
CLIPVisionConfig,
)
from sglang.multimodal_gen.runtime.distributed import divide, get_tp_world_size
from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
from sglang.multimodal_gen.runtime.layers.attention import LocalAttention
from sglang.multimodal_gen.runtime.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig
# TODO: support quantization
# from vllm.model_executor.layers.quantization import QuantizationConfig
from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
from sglang.multimodal_gen.runtime.models.encoders.base import ImageEncoder, TextEncoder
from sglang.multimodal_gen.runtime.models.encoders.vision import (
resolve_visual_encoder_outputs,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
assert self.image_size % self.patch_size == 0
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
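        # One position per image patch plus one for the prepended class embedding.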
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False,
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype)
) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class CLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False,
)
def forward(
self,
input_ids: torch.LongTensor | None = None,
position_ids: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
) -> torch.Tensor:
if input_ids is not None:
seq_length = input_ids.shape[-1]
elif inputs_embeds is not None:
seq_length = inputs_embeds.shape[-2]
else:
raise ValueError("Either input_ids or inputs_embeds must be provided.")
max_position_embedding = self.position_embedding.weight.shape[0]
if seq_length > max_position_embedding:
raise ValueError(
f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
f"{seq_length} and max_position_embeddings: {max_position_embedding}"
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config: CLIPVisionConfig | CLIPTextConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tp_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = LocalAttention(
self.num_heads_per_partition,
self.head_dim,
self.num_heads_per_partition,
softmax_scale=self.scale,
causal=False,
supported_attention_backends=config._supported_attention_backends,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
):
"""Input shape: Batch x Time x Channel"""
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
# use flash_attn_func
query_states = query_states.reshape(
query_states.shape[0],
query_states.shape[1],
self.num_heads_per_partition,
self.head_dim,
)
key_states = key_states.reshape(
key_states.shape[0],
key_states.shape[1],
self.num_heads_per_partition,
self.head_dim,
)
value_states = value_states.reshape(
value_states.shape[0],
value_states.shape[1],
self.num_heads_per_partition,
self.head_dim,
)
attn_output = self.attn(query_states, key_states, value_states)
attn_output = attn_output.reshape(
attn_output.shape[0],
attn_output.shape[1],
self.num_heads_per_partition * self.head_dim,
)
attn_output, _ = self.out_proj(attn_output)
return attn_output, None
class CLIPMLP(nn.Module):
def __init__(
self,
config: CLIPVisionConfig | CLIPTextConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class CLIPEncoderLayer(nn.Module):
def __init__(
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.self_attn = CLIPAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp")
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class CLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def __init__(
self,
config: CLIPVisionConfig | CLIPTextConfig,
quant_config: QuantizationConfig | None = None,
num_hidden_layers_override: int | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList(
[
CLIPEncoderLayer(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(num_hidden_layers)
]
)
def forward(
self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
) -> torch.Tensor | list[torch.Tensor]:
hidden_states_pool = [inputs_embeds]
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(hidden_states)
if return_all_hidden_states:
hidden_states_pool.append(hidden_states)
# If we have multiple feature sample layers, we return all hidden
# states in order and grab the ones we need by index.
if return_all_hidden_states:
return hidden_states_pool
return [hidden_states]
class CLIPTextTransformer(nn.Module):
def __init__(
self,
config: CLIPTextConfig,
quant_config: QuantizationConfig | None = None,
num_hidden_layers_override: int | None = None,
prefix: str = "",
):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=prefix,
)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
# For `pooled_output` computation
self.eos_token_id = config.eos_token_id
def forward(
self,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
) -> BaseEncoderOutput:
r"""
Returns:
"""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
if input_ids is None:
raise ValueError("You have to specify input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
# causal_attention_mask = _create_4d_causal_attention_mask(
# input_shape, hidden_states.dtype, device=hidden_states.device
# )
# # expand attention_mask
# if attention_mask is not None and not self._use_flash_attention_2:
# raise NotImplementedError("attention_mask is not supported for CLIPTextTransformer")
# # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
# attention_mask=attention_mask,
# causal_attention_mask=causal_attention_mask,
# output_attentions=output_attentions,
return_all_hidden_states=output_hidden_states,
# return_dict=return_dict,
)
last_hidden_state = encoder_outputs[-1]
last_hidden_state = self.final_layer_norm(last_hidden_state)
if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
# ------------------------------------------------------------
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
dim=-1
),
]
else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
pooled_output = last_hidden_state[
torch.arange(
last_hidden_state.shape[0], device=last_hidden_state.device
),
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
# Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
(
input_ids.to(dtype=torch.int, device=last_hidden_state.device)
== self.eos_token_id
)
.int()
.argmax(dim=-1),
]
return BaseEncoderOutput(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs,
# attentions=encoder_outputs.attentions,
)
class CLIPTextModel(TextEncoder):
def __init__(
self,
config: CLIPTextConfig,
) -> None:
super().__init__(config)
self.text_model = CLIPTextTransformer(
config=config, quant_config=config.quant_config, prefix=config.prefix
)
def forward(
self,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
**kwargs,
) -> BaseEncoderOutput:
outputs: BaseEncoderOutput = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
)
return outputs
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# Define mapping for stacked parameters
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Handle q_proj, k_proj, v_proj -> qkv_proj mapping
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in name:
# Replace the weight name with the parameter name
model_param_name = name.replace(weight_name, param_name)
if model_param_name in params_dict:
param = params_dict[model_param_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(model_param_name)
break
else:
# Use default weight loader for all other parameters
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class CLIPVisionTransformer(nn.Module):
def __init__(
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(config)
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
# If possible, skip post_layernorm to conserve memory
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
else:
self.post_layernorm = None
def forward(
self,
pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = None,
feature_sample_layers: list[int] | None = None,
) -> BaseEncoderOutput:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
return_all_hidden_states = output_hidden_states or (
feature_sample_layers is not None
)
# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states,
)
if not return_all_hidden_states:
encoder_outputs = encoder_outputs[0]
# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs,
feature_sample_layers,
self.post_layernorm,
self.config.num_hidden_layers,
)
if return_all_hidden_states:
return BaseEncoderOutput(hidden_states=encoder_outputs)
return BaseEncoderOutput(last_hidden_state=encoder_outputs)
class CLIPVisionModel(ImageEncoder):
config_class = CLIPVisionConfig
main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(self, config: CLIPVisionConfig) -> None:
super().__init__(config)
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=config.quant_config,
num_hidden_layers_override=config.num_hidden_layers_override,
require_post_norm=config.require_post_norm,
prefix=f"{config.prefix}.vision_model",
)
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: list[int] | None = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseEncoderOutput:
base_encoder_output = self.vision_model(
pixel_values,
output_hidden_states=output_hidden_states,
feature_sample_layers=feature_sample_layers,
)
return base_encoder_output
@property
def device(self):
return next(self.parameters()).device
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
if name.startswith("visual_projection"):
continue
# post_layernorm is not needed in CLIPVisionModel
if (
name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None
):
continue
# omit layers when num_hidden_layers_override is set
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
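            # Fold separate q/k/v checkpoint tensors into the fused qkv_proj parameter.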
for (
param_name,
weight_name,
shard_id,
) in self.config.arch_config.stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BertModel(CLIPTextModel):
pass
EntryClass = [CLIPTextModel, CLIPVisionModel]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/llama.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Any
import torch
from torch import nn
# from ..utils import (extract_layer_index)
from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, LlamaConfig
from sglang.multimodal_gen.runtime.distributed import get_tp_world_size
from sglang.multimodal_gen.runtime.layers.activation import SiluAndMul
# from vllm.model_executor.layers.quantization import QuantizationConfig
from sglang.multimodal_gen.runtime.layers.attention import LocalAttention
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig
from sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope
from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.multimodal_gen.runtime.loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
# output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
bias_o_proj: bool = False,
prefix: str = "",
) -> None:
super().__init__()
# layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
tp_size = get_tp_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
is_gguf = (
quant_config
and hasattr(quant_config, "get_name")
and quant_config.get_name() == "gguf"
)
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = LocalAttention(
self.num_heads,
self.head_dim,
self.num_kv_heads,
softmax_scale=self.scaling,
causal=True,
supported_attention_backends=config._supported_attention_backends,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
# attn_output = self.attn(q, k, v)
# use flash_attn_func
# TODO (Attn abstraction and backend)
# reshape q, k, v to (batch_size, seq_len, num_heads, head_dim)
batch_size = q.shape[0]
seq_len = q.shape[1]
q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# import pdb; pdb.set_trace()
# attn_output = flash_attn_varlen_func(q, k, v, softmax_scale=self.scaling, causal=True)
attn_output = self.attn(q, k, v)
attn_output = attn_output.reshape(
batch_size, seq_len, self.num_heads * self.head_dim
)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
)
bias_o_proj = attention_bias
# support internlm/internlm3-8b with qkv_bias
if hasattr(config, "qkv_bias"):
attention_bias = config.qkv_bias
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
prefix=f"{prefix}.self_attn",
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(TextEncoder):
def __init__(
self,
config: LlamaConfig,
):
super().__init__(config)
self.config = config
self.quant_config = self.config.quant_config
if config.lora_config is not None:
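            # Reserve extra vocabulary slots for LoRA-added tokens (one block per LoRA).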
max_loras = 1
lora_vocab_size = 1
if hasattr(config.lora_config, "max_loras"):
max_loras = config.lora_config.max_loras
if hasattr(config.lora_config, "lora_extra_vocab_size"):
lora_vocab_size = config.lora_config.lora_extra_vocab_size
lora_vocab = lora_vocab_size * max_loras
else:
lora_vocab = 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=config.quant_config,
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(
config=config,
quant_config=config.quant_config,
prefix=f"{config.prefix}.layers.{i}",
)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
**kwargs,
) -> BaseEncoderOutput:
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
if position_ids is None:
position_ids = torch.arange(
0, hidden_states.shape[1], device=hidden_states.device
).unsqueeze(0)
all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None
for layer in self.layers:
if all_hidden_states is not None:
# TODO
all_hidden_states += (
(hidden_states,)
if residual is None
else (hidden_states + residual,)
)
hidden_states, residual = layer(position_ids, hidden_states, residual)
hidden_states, _ = self.norm(hidden_states, residual)
# add hidden states from the last decoder layer
if all_hidden_states is not None:
all_hidden_states += (hidden_states,)
# TODO(will): maybe unify the output format with other models and use
# our own class
output = BaseEncoderOutput(
last_hidden_state=hidden_states,
# past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
# attentions=all_self_attns,
)
return output
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# if (self.quant_config is not None and
# (scale_name := self.quant_config.get_cache_scale(name))):
# # Loading kv cache quantization scales
# param = params_dict[scale_name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
# loaded_weight[0])
# weight_loader(param, loaded_weight)
# loaded_params.add(scale_name)
# continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
kv_scale_name: str | None = maybe_remap_kv_scale_name(name, params_dict)
if kv_scale_name is None:
continue
else:
name = kv_scale_name
for (
param_name,
weight_name,
shard_id,
) in self.config.arch_config.stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
EntryClass = LlamaModel
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from types import SimpleNamespace
from transformers import (
Cache,
DynamicCache,
PretrainedConfig,
Qwen2_5_VLTextConfig,
Qwen2RMSNorm,
)
from transformers.masking_utils import (
create_causal_mask,
create_sliding_window_causal_mask,
)
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import TransformersKwargs, is_torchdynamo_compiling
from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig
from sglang.multimodal_gen.runtime.layers.attention import LocalAttention
from sglang.multimodal_gen.runtime.layers.linear import (
MergedColumnParallelLinear,
RowParallelLinear,
)
from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig
from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.utils.common import add_prefix
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import logging
from typing import Callable, Iterable, Optional, Tuple, Union, Unpack
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLAttention,
Qwen2_5_VLCausalLMOutputWithPast,
Qwen2_5_VLModelOutputWithPast,
Qwen2_5_VLRotaryEmbedding,
Qwen2MLP,
apply_multimodal_rotary_pos_emb,
eager_attention_forward,
)
logger = logging.getLogger(__name__)
class Qwen2_5_VLAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
"""
def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
            logger.warning(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.is_causal = True
self.attention_dropout = config.attention_dropout
self.rope_scaling = config.rope_scaling
self.scaling = self.head_dim**-0.5
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=True
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.sliding_window = (
config.sliding_window
if config.layer_types[layer_idx] == "sliding_attention"
else None
)
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
self.attn = LocalAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
num_kv_heads=self.num_key_value_heads,
softmax_scale=self.scaling,
causal=True,
supported_attention_backends=(
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
),
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
if past_key_values is not None:
cache_kwargs = {
"sin": sin,
"cos": cos,
"cache_position": cache_position,
} # Specific to RoPE models
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
# if self.config._attn_implementation != "eager":
# attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = self.attn(query_states, key_states, value_states)
#
# attn_output, attn_weights = attention_interface(
# self,
# query_states,
# key_states,
# value_states,
# attention_mask,
# dropout=0.0 if not self.training else self.attention_dropout,
# scaling=self.scaling,
# sliding_window=self.sliding_window,
# position_ids=position_ids, # pass positions for FA2
# **kwargs,
# )
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
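# Shape sketch for the attention forward above (descriptive only, sizes are
# hypothetical):
#   hidden_states           : [bsz, q_len, hidden_size]
#   q / k / v after view+T  : [bsz, num_heads, q_len, head_dim]
#                             (k / v use num_key_value_heads)
#   after mrope + cache update, q/k/v are transposed back to
#   [bsz, q_len, heads, head_dim] for LocalAttention, whose output is reshaped
#   to [bsz, q_len, num_heads * head_dim] and projected by o_proj.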
class Qwen2_5_VLDecoderLayer(nn.Module):
def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
if (
config.use_sliding_window
and config._attn_implementation != "flash_attention_2"
):
logger.warning(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
)
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.attention_type = config.layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[
torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Qwen2_5_VLMLP(nn.Module):
def __init__(
self,
in_features: int,
        hidden_features: Optional[int] = None,
bias: bool = True,
hidden_act="silu",
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
)
self.act = ACT2FN[hidden_act]
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
gate, up = gate_up.chunk(2, dim=-1)
x = self.act(gate) * up
x_down, _ = self.down_proj(x)
return x_down
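# The fused gate_up_proj above packs [gate | up] along the output dim, and the
# forward computes SiLU(gate) * up before down_proj, i.e. a SwiGLU MLP. A
# minimal functional sketch of the same math with hypothetical sizes (not part
# of the model code):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.randn(2, 8, 16)              # [batch, seq, in_features]
#   w_gate_up = torch.randn(2 * 32, 16)    # stands in for gate_up_proj.weight
#   w_down = torch.randn(16, 32)           # stands in for down_proj.weight
#   gate, up = (x @ w_gate_up.t()).chunk(2, dim=-1)
#   y = (F.silu(gate) * up) @ w_down.t()   # [2, 8, 16]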
class Qwen2_5_VLTextModel(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
Qwen2_5_VLDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
self.has_sliding_layers = "sliding_attention" in self.config.layer_types
self.gradient_checkpointing = False
# Initialize weights and apply final processing
# self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if self.gradient_checkpointing and self.training:
if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
use_cache = False
# torch.jit.trace() doesn't support cache objects in the output
if use_cache and past_key_values is None and not torch.jit.is_tracing():
past_key_values = DynamicCache(config=self.config)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# the hard coded `3` is for temporal, height and width.
if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(
3, inputs_embeds.shape[0], -1
)
elif position_ids.ndim == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
# where each dim indicates visual spatial positions for temporal/height/width grids.
# There are two scenarios when FA2-like packed masking might be activated.
# 1. User specifically passed packed `position_ids` and no attention mask.
# In this case we expect the user to create correct position ids for all 3 grids
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
text_position_ids = position_ids[0]
position_ids = position_ids[1:]
else:
text_position_ids = position_ids[0]
# It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments
mask_kwargs = {
"config": self.config,
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"cache_position": cache_position,
"past_key_values": past_key_values,
"position_ids": text_position_ids,
}
# Create the masks
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
}
# The sliding window alternating layers are not always activated depending on the config
if self.has_sliding_layers:
causal_mask_mapping["sliding_attention"] = (
create_sliding_window_causal_mask(**mask_kwargs)
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=text_position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
past_key_values,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
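# Position-id convention accepted by the text model above (descriptive note):
#   position_ids of shape [3, bsz, seq_len] carry the temporal/height/width
#   mrope grids; a [4, bsz, seq_len] tensor additionally carries text-only
#   positions in row 0, which are split off and used only for causal-mask
#   construction, while rows 1-3 feed the rotary embedding.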
class Qwen2_5_VLModel(nn.Module):
base_model_prefix = ""
_checkpoint_conversion_mapping = {"^model": "language_model"}
# Reference: fix gemma3 grad acc #37208
accepts_loss_kwargs = False
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
def __init__(self, config):
super().__init__()
self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
config.vision_config
)
self.language_model = Qwen2_5_VLTextModel(config.text_config)
self.visual.to(torch.get_default_dtype())
self.rope_deltas = None # cache rope_deltas here
self.config = config
# Initialize weights and apply final processing
# self.post_init()
def get_input_embeddings(self):
return self.language_model.embed_tokens
def set_input_embeddings(self, value):
self.language_model.embed_tokens = value
def set_decoder(self, decoder):
self.language_model = decoder
def get_decoder(self):
return self.language_model
def get_rope_index(
self,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
            For a pure text embedding sequence, the rotary position embedding is no different from that of modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embedding for text part.
Examples:
Temporal (Time): 3 patches, representing different segments of the video in time.
Height: 2 patches, dividing each frame vertically.
Width: 2 patches, dividing each frame horizontally.
We also have some important parameters:
fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
            interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [101, 102, 103, 104, 105]
text height position_ids: [101, 102, 103, 104, 105]
text width position_ids: [101, 102, 103, 104, 105]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
## normalize type, send to device.
second_per_grid_t = torch.as_tensor(
second_per_grid_t,
dtype=range_tensor.dtype,
device=range_tensor.device,
)
time_tensor = (
expanded_range
* second_per_grid_t
* self.config.vision_config.tokens_per_second
)
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
position_ids.device
)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = (
position_ids.unsqueeze(0)
.expand(3, -1, -1)
.to(attention_mask.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def get_video_features(
self,
pixel_values_videos: torch.FloatTensor,
video_grid_thw: Optional[torch.LongTensor] = None,
):
"""
Encodes videos into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
split_sizes = (
video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
).tolist()
video_embeds = torch.split(video_embeds, split_sizes)
return video_embeds
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_grid_thw: Optional[torch.LongTensor] = None,
):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
"""
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
split_sizes = (
image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
).tolist()
image_embeds = torch.split(image_embeds, split_sizes)
return image_embeds
def get_placeholder_mask(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
        image_features: Optional[torch.FloatTensor] = None,
        video_features: Optional[torch.FloatTensor] = None,
):
"""
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
equal to the length of multimodal features. If the lengths are different, an error is raised.
"""
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.image_token_id,
dtype=torch.long,
device=inputs_embeds.device,
)
)
special_image_mask = special_image_mask.all(-1)
special_video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.video_token_id,
dtype=torch.long,
device=inputs_embeds.device,
)
)
special_video_mask = special_video_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_video_mask = input_ids == self.config.video_token_id
n_image_tokens = special_image_mask.sum()
special_image_mask = (
special_image_mask.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
if (
image_features is not None
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
)
n_video_tokens = special_video_mask.sum()
special_video_mask = (
special_video_mask.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
if (
video_features is not None
and inputs_embeds[special_video_mask].numel() != video_features.numel()
):
raise ValueError(
f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
)
return special_image_mask, special_video_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]:
r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0).to(
inputs_embeds.device, inputs_embeds.dtype
)
image_mask, _ = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0).to(
inputs_embeds.device, inputs_embeds.dtype
)
_, video_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if position_ids is None:
            # Calculate the RoPE index once per generation, in the pre-fill stage only.
            # When compiling, we can't check tensor values, so we only check the input length.
            # It is safe to assume that `length != 1` means we're in pre-fill, because compiled
            # models currently cannot do assisted decoding.
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or (past_key_values is None or past_key_values.get_seq_length() == 0)
)
if (
prefill_compiled_stage or prefill_noncompiled_stage
) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask,
)
self.rope_deltas = rope_deltas
else:
batch_size, seq_length, _ = inputs_embeds.shape
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
if cache_position is not None:
delta = (cache_position[0] + self.rope_deltas).to(
inputs_embeds.device
)
else:
delta = torch.zeros(
(batch_size, seq_length), device=inputs_embeds.device
)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
position_ids += delta.to(position_ids.device)
outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
output = Qwen2_5_VLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
return output if return_dict else output.to_tuple()
class DotDict(dict):
def __init__(self, mapping):
super().__init__()
for key, value in mapping.items():
if isinstance(value, dict):
                value = DotDict(value)  # recursively convert nested dicts
            elif isinstance(value, list):
                # if the value is a list, also recursively convert any dict elements
                value = [
                    DotDict(item) if isinstance(item, dict) else item for item in value
                ]
self[key] = value
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"No attribute '{item}'")
def __setattr__(self, key, value):
self[key] = value
def __delattr__(self, key):
del self[key]
def dict_to_namespace(d):
for k, v in d.items():
if isinstance(v, dict):
d[k] = dict_to_namespace(v)
elif isinstance(v, list):
d[k] = [dict_to_namespace(i) if isinstance(i, dict) else i for i in v]
return SimpleNamespace(**d)
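# Illustrative usage sketch for the two helpers above; the config literals are
# hypothetical and only show the attribute-style access they enable:
#
#   cfg = DotDict({"text_config": {"hidden_size": 1024}})
#   assert cfg.text_config.hidden_size == 1024
#   ns = dict_to_namespace({"vision_config": {"spatial_merge_size": 2}})
#   assert ns.vision_config.spatial_merge_size == 2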
class Qwen2_5_VLForConditionalGeneration(TextEncoder):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_up_proj.",
".down_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
config: Qwen2_5VLConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config)
config = config.arch_config
self.model = Qwen2_5_VLModel(config)
self.lm_head = nn.Linear(
config.text_config.hidden_size, config.text_config.vocab_size, bias=False
)
self.config = config
def get_input_embeddings(self):
        return self.model.get_input_embeddings()
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
):
"""Run forward pass for Qwen2_5-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
(Use input_metadata.mrope_positions to replace it)
"""
output_attentions = False
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
return Qwen2_5_VLCausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=outputs.rope_deltas,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loaded_params: set[str] = set()
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
name = name.replace("model.", "model.language_model.")
if "visual." in name:
name = name.replace("visual.", "model.visual.")
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
                logger.error(
                    "Parameter %s not found in params_dict; known keys: %s",
                    name,
                    list(params_dict.keys()),
                )
raise
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight.to(param.dtype)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def get_embed_and_head(self):
        return self.model.get_input_embeddings().weight, self.lm_head.weight
EntryClass = Qwen2_5_VLForConditionalGeneration
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# type: ignore
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import os
from functools import wraps
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from sglang.multimodal_gen.runtime.models.dits.stepvideo import StepVideoRMSNorm
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
def __init__(self, device=None):
self.device = device
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if getattr(func, "__module__", None) == "torch.nn.init":
if "tensor" in kwargs:
return kwargs["tensor"]
else:
return args[0]
if (
self.device is not None
and func in torch.utils._device._device_constructors()
and kwargs.get("device") is None
):
kwargs["device"] = self.device
return func(*args, **kwargs)
def with_empty_init(func):
@wraps(func)
def wrapper(*args, **kwargs):
with EmptyInitOnDevice("cpu"):
return func(*args, **kwargs)
return wrapper
class LLaMaEmbedding(nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self,
cfg,
):
super().__init__()
self.hidden_size = cfg.hidden_size
self.params_dtype = cfg.params_dtype
self.fp32_residual_connection = cfg.fp32_residual_connection
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
self.word_embeddings = torch.nn.Embedding(
cfg.padded_vocab_size,
self.hidden_size,
)
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
def forward(self, input_ids):
# Embeddings.
if self.embedding_weights_in_fp32:
self.word_embeddings = self.word_embeddings.to(torch.float32)
embeddings = self.word_embeddings(input_ids)
if self.embedding_weights_in_fp32:
embeddings = embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
embeddings = embeddings.float()
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
class StepChatTokenizer:
"""Step Chat Tokenizer"""
def __init__(
self,
model_file,
name="StepChatTokenizer",
bot_token="<|BOT|>", # Begin of Turn
eot_token="<|EOT|>", # End of Turn
call_start_token="<|CALL_START|>", # Call Start
call_end_token="<|CALL_END|>", # Call End
think_start_token="<|THINK_START|>", # Think Start
think_end_token="<|THINK_END|>", # Think End
mask_start_token="<|MASK_1e69f|>", # Mask start
mask_end_token="<|UNMASK_1e69f|>", # Mask end
):
import sentencepiece
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self._vocab = {}
self._inv_vocab = {}
self._special_tokens = {}
self._inv_special_tokens = {}
self._t5_tokens = []
for idx in range(self._tokenizer.get_piece_size()):
text = self._tokenizer.id_to_piece(idx)
self._inv_vocab[idx] = text
self._vocab[text] = idx
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
self._special_tokens[text] = idx
self._inv_special_tokens[idx] = text
self._unk_id = self._tokenizer.unk_id()
self._bos_id = self._tokenizer.bos_id()
self._eos_id = self._tokenizer.eos_id()
for token in [
bot_token,
eot_token,
call_start_token,
call_end_token,
think_start_token,
think_end_token,
]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
assert (
token in self._special_tokens
), f"Token '{token}' is not a special token"
for token in [mask_start_token, mask_end_token]:
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
self._bot_id = self._tokenizer.piece_to_id(bot_token)
self._eot_id = self._tokenizer.piece_to_id(eot_token)
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
self._underline_id = self._tokenizer.piece_to_id("\u2581")
@property
def vocab(self):
return self._vocab
@property
def inv_vocab(self):
return self._inv_vocab
@property
def vocab_size(self):
return self._tokenizer.vocab_size()
def tokenize(self, text: str) -> list[int]:
return self._tokenizer.encode_as_ids(text)
def detokenize(self, token_ids: list[int]) -> str:
return self._tokenizer.decode_ids(token_ids)
class Tokens:
def __init__(
self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len
) -> None:
self.input_ids = input_ids
self.attention_mask = attention_mask
self.cu_input_ids = cu_input_ids
self.cu_seqlens = cu_seqlens
self.max_seq_len = max_seq_len
def to(self, device):
self.input_ids = self.input_ids.to(device)
self.attention_mask = self.attention_mask.to(device)
self.cu_input_ids = self.cu_input_ids.to(device)
self.cu_seqlens = self.cu_seqlens.to(device)
return self
class Wrapped_StepChatTokenizer(StepChatTokenizer):
def __call__(
self,
text,
max_length=320,
padding="max_length",
truncation=True,
return_tensors="pt",
):
# [bos, ..., eos, pad, pad, ..., pad]
self.BOS = 1
self.EOS = 2
self.PAD = 2
out_tokens = []
attn_mask = []
if len(text) == 0:
part_tokens = [self.BOS] + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1] * valid_size + [0] * (max_length - valid_size))
else:
for part in text:
part_tokens = self.tokenize(part)
part_tokens = part_tokens[
: (max_length - 2)
                ]  # leave two slots for bos and eos
part_tokens = [self.BOS] + part_tokens + [self.EOS]
valid_size = len(part_tokens)
if len(part_tokens) < max_length:
part_tokens += [self.PAD] * (max_length - valid_size)
out_tokens.append(part_tokens)
attn_mask.append([1] * valid_size + [0] * (max_length - valid_size))
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
# padding y based on tp size
padded_len = 0
padded_flag = False
if padded_len > 0:
padded_flag = True
if padded_flag:
pad_tokens = torch.tensor(
[[self.PAD] * max_length], device=out_tokens.device
)
pad_attn_mask = torch.tensor(
[[1] * padded_len + [0] * (max_length - padded_len)],
device=attn_mask.device,
)
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
# cu_seqlens
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
seqlen = attn_mask.sum(dim=1).tolist()
cu_seqlens = torch.cumsum(torch.tensor([0] + seqlen), 0).to(
device=out_tokens.device, dtype=torch.int32
)
max_seq_len = max(seqlen)
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
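# Illustrative sketch of the tokenizer output above (the model path and token
# ids are hypothetical):
#
#   tokens = Wrapped_StepChatTokenizer("/path/to/step1_chat_tokenizer.model")(
#       ["a cat"], max_length=6
#   )
#   tokens.input_ids       # [[1, 17, 93, 2, 2, 2]]  -> BOS, ids, EOS, PAD, PAD
#   tokens.attention_mask  # [[1, 1, 1, 1, 0, 0]]
#   tokens.cu_input_ids    # the non-padded ids flattened to shape [1, 4]
#   tokens.cu_seqlens      # tensor([0, 4], dtype=int32), cumulative valid lengths
#   tokens.max_seq_len     # 4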
def flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=True,
return_attn_probs=False,
tp_group_rank=0,
tp_group_size=1,
):
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
return torch.ops.Optimus.fwd(
q,
k,
v,
None,
dropout_p,
softmax_scale,
causal,
return_attn_probs,
None,
tp_group_rank,
tp_group_size,
)[0]
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
attention_dropout=0.0,
):
super().__init__()
self.dropout_p = attention_dropout
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
if cu_seqlens is None:
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
else:
raise ValueError("cu_seqlens is not supported!")
return output
def safediv(n, d):
q, r = divmod(n, d)
assert r == 0
return q
class MultiQueryAttention(nn.Module):
def __init__(self, cfg, layer_id=None):
super().__init__()
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.max_seq_len = cfg.seq_length
self.use_flash_attention = cfg.use_flash_attn
assert self.use_flash_attention, "FlashAttention is required!"
self.n_groups = cfg.num_attention_groups
self.tp_size = 1
self.n_local_heads = cfg.num_attention_heads
self.n_local_groups = self.n_groups
self.wqkv = nn.Linear(
cfg.hidden_size,
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
bias=False,
)
self.wo = nn.Linear(
cfg.hidden_size,
cfg.hidden_size,
bias=False,
)
# assert self.use_flash_attention, 'non-Flash attention not supported yet.'
self.core_attention = FlashSelfAttention(
attention_dropout=cfg.attention_dropout
)
# self.core_attention = LocalAttention(
# num_heads = self.n_local_heads,
# head_size = self.head_dim,
# # num_kv_heads = self.n_local_groups,
# casual = True,
# supported_attention_backends = [_Backend.FLASH_ATTN, _Backend.TORCH_SDPA], # RIVER TODO
# )
self.layer_id = layer_id
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None,
cu_seqlens: torch.Tensor | None,
max_seq_len: torch.Tensor | None,
):
seqlen, bsz, dim = x.shape
xqkv = self.wqkv(x)
xq, xkv = torch.split(
xqkv,
(dim // self.tp_size, self.head_dim * 2 * self.n_groups // self.tp_size),
dim=-1,
)
# gather on 1st dimension
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
xk, xv = xkv.chunk(2, -1)
# rotary embedding + flash attn
xq = rearrange(xq, "s b h d -> b s h d")
xk = rearrange(xk, "s b h d -> b s h d")
xv = rearrange(xv, "s b h d -> b s h d")
# q_per_kv = self.n_local_heads // self.n_local_groups
# if q_per_kv > 1:
# b, s, h, d = xk.size()
# if h == 1:
# xk = xk.expand(b, s, q_per_kv, d)
# xv = xv.expand(b, s, q_per_kv, d)
# else:
# ''' To cover the cases where h > 1, we have
# the following implementation, which is equivalent to:
# xk = xk.repeat_interleave(q_per_kv, dim=-2)
# xv = xv.repeat_interleave(q_per_kv, dim=-2)
# but can avoid calling aten::item() that involves cpu.
# '''
# idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
# xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
# xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
if self.use_flash_attention:
output = self.core_attention(xq, xk, xv)
# reduce-scatter only support first dimension now
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
else:
xq, xk, xv = [
rearrange(x, "b s ... -> s b ...").contiguous() for x in (xq, xk, xv)
]
output = self.core_attention(xq, xk, xv) # , mask)
output = self.wo(output)
return output
class FeedForward(nn.Module):
def __init__(
self,
cfg,
dim: int,
hidden_dim: int,
layer_id: int,
multiple_of: int = 256,
):
super().__init__()
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
def swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]) * x[1]
self.swiglu = swiglu
self.w1 = nn.Linear(
dim,
2 * hidden_dim,
bias=False,
)
self.w2 = nn.Linear(
hidden_dim,
dim,
bias=False,
)
def forward(self, x):
x = self.swiglu(self.w1(x))
output = self.w2(x)
return output
class TransformerBlock(nn.Module):
def __init__(self, cfg, layer_id: int):
super().__init__()
self.n_heads = cfg.num_attention_heads
self.dim = cfg.hidden_size
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
self.attention = MultiQueryAttention(
cfg,
layer_id=layer_id,
)
self.feed_forward = FeedForward(
cfg,
dim=cfg.hidden_size,
hidden_dim=cfg.ffn_hidden_size,
layer_id=layer_id,
)
self.layer_id = layer_id
self.attention_norm = StepVideoRMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
self.ffn_norm = StepVideoRMSNorm(
cfg.hidden_size,
eps=cfg.layernorm_epsilon,
)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None,
cu_seqlens: torch.Tensor | None,
max_seq_len: torch.Tensor | None,
):
residual = self.attention.forward(
self.attention_norm(x), mask, cu_seqlens, max_seq_len
)
h = x + residual
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
out = h + ffn_res
return out
class Transformer(nn.Module):
def __init__(
self,
config,
max_seq_size=8192,
):
super().__init__()
self.num_layers = config.num_layers
self.layers = self._build_layers(config)
def _build_layers(self, config):
layers = torch.nn.ModuleList()
for layer_id in range(self.num_layers):
layers.append(
TransformerBlock(
config,
layer_id=layer_id + 1,
)
)
return layers
def forward(
self,
hidden_states,
attention_mask,
cu_seqlens=None,
max_seq_len=None,
):
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
for lid, layer in enumerate(self.layers):
hidden_states = layer(
hidden_states,
attention_mask,
cu_seqlens,
max_seq_len,
)
return hidden_states
class Step1Model(PreTrainedModel):
config_class = PretrainedConfig
@with_empty_init
def __init__(
self,
config,
):
super().__init__(config)
self.tok_embeddings = LLaMaEmbedding(config)
self.transformer = Transformer(config)
def forward(
self,
input_ids=None,
attention_mask=None,
):
hidden_states = self.tok_embeddings(input_ids)
hidden_states = self.transformer(
hidden_states,
attention_mask,
)
return hidden_states
class STEP1TextEncoder(torch.nn.Module):
def __init__(self, model_dir, max_length=320):
super().__init__()
self.max_length = max_length
self.text_tokenizer = Wrapped_StepChatTokenizer(
os.path.join(model_dir, "step1_chat_tokenizer.model")
)
text_encoder = Step1Model.from_pretrained(model_dir)
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
@torch.no_grad
def forward(self, prompts, with_mask=True, max_length=None):
self.device = next(self.text_encoder.parameters()).device
with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
if type(prompts) is str:
prompts = [prompts]
txt_tokens = self.text_tokenizer(
prompts,
max_length=max_length or self.max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
y = self.text_encoder(
txt_tokens.input_ids.to(self.device),
attention_mask=(
txt_tokens.attention_mask.to(self.device) if with_mask else None
),
)
y_mask = txt_tokens.attention_mask
return y.transpose(0, 1), y_mask
EntryClass = STEP1TextEncoder
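# Illustrative usage sketch for STEP1TextEncoder (the checkpoint path is
# hypothetical; it must contain step1_chat_tokenizer.model plus the Step1Model
# weights):
#
#   encoder = STEP1TextEncoder("/path/to/step1", max_length=320).cuda()
#   y, y_mask = encoder(["a panda riding a bicycle"])
#   # y      : [batch, max_length, hidden_size] prompt embeddings (bfloat16)
#   # y_mask : [batch, max_length] attention mask over valid tokens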
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/t5/modeling_t5.py
# Derived from T5 implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch T5 & UMT5 model."""
import math
from collections.abc import Iterable
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, T5Config
from sglang.multimodal_gen.runtime.distributed import get_tp_rank, get_tp_world_size
from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm
from sglang.multimodal_gen.runtime.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig
from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader
from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder
from sglang.multimodal_gen.runtime.platforms import current_platform
class AttentionType:
"""
Attention type.
Use string to be compatible with `torch.compile`.
"""
# Decoder attention between previous layer Q/K/V
DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
ENCODER_DECODER = "encoder_decoder"
_seen_keys = set()  # use a set to record keys that have already been seen
@dataclass
class AttentionMetadata:
attn_bias: torch.Tensor
class T5DenseActDense(nn.Module):
def __init__(
self, config: T5Config, quant_config: QuantizationConfig | None = None
):
super().__init__()
self.wi = MergedColumnParallelLinear(config.d_model, [config.d_ff], bias=False)
self.wo = RowParallelLinear(
config.d_ff, config.d_model, bias=False, quant_config=quant_config
)
self.act = get_act_fn(config.dense_act_fn)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.wo(hidden_states)
return hidden_states
class T5DenseGatedActDense(nn.Module):
def __init__(
self, config: T5Config, quant_config: QuantizationConfig | None = None
):
super().__init__()
self.wi_0 = MergedColumnParallelLinear(
config.d_model, [config.d_ff], bias=False, quant_config=quant_config
)
self.wi_1 = MergedColumnParallelLinear(
config.d_model, [config.d_ff], bias=False, quant_config=quant_config
)
# Should not run in fp16 unless mixed-precision is used,
# see https://github.com/huggingface/transformers/issues/20287.
self.wo = RowParallelLinear(
config.d_ff, config.d_model, bias=False, quant_config=quant_config
)
self.act = get_act_fn(config.dense_act_fn)
def forward(self, hidden_states) -> torch.Tensor:
hidden_gelu = self.act(self.wi_0(hidden_states)[0])
hidden_linear, _ = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states, _ = self.wo(hidden_states)
return hidden_states
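# Gated-activation note for the block above: wi_0 feeds the nonlinearity and
# wi_1 the linear branch, so with a GELU dense_act_fn this matches the gated
# feed-forward used by T5 v1.1 / UMT5:
#   h = act(wi_0(x)) * wi_1(x);  y = wo(h)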
class T5LayerFF(nn.Module):
def __init__(
self, config: T5Config, quant_config: QuantizationConfig | None = None
):
super().__init__()
if config.is_gated_act:
self.DenseReluDense = T5DenseGatedActDense(
config, quant_config=quant_config
)
else:
self.DenseReluDense = T5DenseActDense(config, quant_config=quant_config)
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
def forward(self, hidden_states) -> torch.Tensor:
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + forwarded_states
return hidden_states
# T5 has attn_bias and does not use softmax scaling
class T5MultiHeadAttention(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, q, k, v, attn_bias=None):
b, _, n, c = q.shape
attn = torch.einsum("binc,bjnc->bnij", q, k)
if attn_bias is not None:
attn += attn_bias
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
x = torch.einsum("bnij,bjnc->binc", attn, v)
x = x.reshape(b, -1, n * c)
return x
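# Shape sketch for the attention above (T5 applies no 1/sqrt(d) softmax scaling
# and instead adds a learned relative-position bias via attn_bias):
#   q, k, v  : [b, seq, n_heads, head_dim]
#   attn     : [b, n_heads, q_len, k_len]   (attn_bias broadcasts to this shape)
#   output x : [b, q_len, n_heads * head_dim]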
class T5Attention(nn.Module):
def __init__(
self,
config: T5Config,
attn_type: str,
has_relative_attention_bias=False,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.attn_type = attn_type
# Cross-attention has no relative pos encoding anyway
self.is_decoder = attn_type == AttentionType.DECODER
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.total_num_heads = self.total_num_kv_heads = config.num_heads
# Partition heads across multiple tensor parallel GPUs.
tp_world_size = get_tp_world_size()
assert config.num_heads % tp_world_size == 0
self.n_heads = config.num_heads // tp_world_size
self.inner_dim = self.n_heads * self.key_value_proj_dim
# No GQA in t5.
# self.n_kv_heads = self.n_heads
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.attn = T5MultiHeadAttention()
if self.has_relative_attention_bias:
self.relative_attention_bias = VocabParallelEmbedding(
self.relative_attention_num_buckets,
self.total_num_heads,
org_num_embeddings=self.relative_attention_num_buckets,
padding_size=self.relative_attention_num_buckets,
quant_config=quant_config,
)
self.o = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
@staticmethod
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
) -> torch.Tensor:
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position,
i.e. the distance in tokens from the attending position to the
attended-to position. If bidirectional=False, then positive relative
positions are invalid. We use smaller buckets for small absolute
relative_position and larger buckets for larger absolute
relative_positions. All relative positions >=max_distance map to the
same bucket. All relative positions <=-max_distance map to the same
bucket. This should allow for more graceful generalization to longer
sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
""" # noqa: E501
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(
relative_position, torch.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins
# in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large,
torch.full_like(relative_position_if_large, num_buckets - 1),
)
relative_buckets += torch.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets
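    # Worked sketch (illustrative; assumes the defaults num_buckets=32, max_distance=128,
    # bidirectional=True): the effective bucket count is halved to 16 per direction, offsets
    # with |rel| < 8 keep one exact bucket each, larger offsets are binned logarithmically,
    # and every |rel| >= 128 is clamped into the last bucket of its direction, e.g.
    #     rel = torch.tensor([[0, 1, 7, 50, 500]])
    #     T5Attention._relative_position_bucket(rel)  # exact buckets for the small offsets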
def compute_bias(self, query_length, key_length, device=None) -> torch.Tensor:
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[
:, None
]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
None, :
]
# max_seq_len, nh
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(
relative_position_bucket
) # shape (query_length, key_length, num_heads)
x = values.permute([2, 0, 1]).unsqueeze(
0
) # shape (1, num_heads, query_length, key_length)
return x
def forward(
self,
        hidden_states: torch.Tensor,  # (batch_size, seq_len, d_model)
attention_mask: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
bs, seq_len, _ = hidden_states.shape
num_seqs = bs
n, c = self.n_heads, self.d_model // self.total_num_heads
qkv, _ = self.qkv_proj(hidden_states)
# Projection of 'own' hidden state (self-attention). No GQA here.
q, k, v = qkv.split(self.inner_dim, dim=-1)
q = q.reshape(bs, seq_len, n, c)
k = k.reshape(bs, seq_len, n, c)
v = v.reshape(bs, seq_len, n, c)
assert attn_metadata is not None
attn_bias = attn_metadata.attn_bias
# Not compatible with CP here (as all encoder-decoder models),
# as it assumes homogeneous batch (prefills or decodes).
if self.has_relative_attention_bias:
# Self-attention. Compute T5 relative positional encoding.
# The bias term is computed on longest sequence in batch. Biases
# for shorter sequences are slices of the longest.
assert self.attn_type == AttentionType.ENCODER
attn_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1)
attn_metadata.attn_bias = attn_bias
else:
# Encoder/Decoder Self-Attention Layer, attn bias already cached.
assert attn_bias is not None
if attention_mask is not None:
attention_mask = (
attention_mask.view(bs, 1, 1, -1)
if attention_mask.ndim == 2
else attention_mask.unsqueeze(1)
)
mask_val = -1e4 if current_platform.is_mps() else torch.finfo(q.dtype).min
attn_bias.masked_fill_(attention_mask == 0, mask_val)
if get_tp_world_size() > 1:
rank = get_tp_rank()
attn_bias = attn_bias[
:, rank * self.n_heads : (rank + 1) * self.n_heads, :, :
]
attn_output = self.attn(q, k, v, attn_bias)
output, _ = self.o(attn_output)
return output
class T5LayerSelfAttention(nn.Module):
def __init__(
self,
config,
has_relative_attention_bias=False,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.SelfAttention = T5Attention(
config,
AttentionType.DECODER if "decoder" in prefix else AttentionType.ENCODER,
has_relative_attention_bias=has_relative_attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.SelfAttention",
)
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
hidden_states=normed_hidden_states,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states + attention_output
return hidden_states
class T5LayerCrossAttention(nn.Module):
def __init__(
self, config, quant_config: QuantizationConfig | None = None, prefix: str = ""
):
super().__init__()
self.EncDecAttention = T5Attention(
config,
AttentionType.ENCODER_DECODER,
has_relative_attention_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.EncDecAttention",
)
self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
hidden_states=normed_hidden_states,
attn_metadata=attn_metadata,
)
hidden_states = hidden_states + attention_output
return hidden_states
class T5Block(nn.Module):
def __init__(
self,
config: T5Config,
is_decoder: bool,
has_relative_attention_bias=False,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.is_decoder = is_decoder
self.layer = nn.ModuleList()
self.layer.append(
T5LayerSelfAttention(
config,
has_relative_attention_bias=has_relative_attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
)
if self.is_decoder:
self.layer.append(
T5LayerCrossAttention(
config, quant_config=quant_config, prefix=f"{prefix}.cross_attn"
)
)
self.layer.append(T5LayerFF(config, quant_config=quant_config))
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
if attention_mask is None:
attention_mask = torch.ones(
hidden_states.shape[:2], device=hidden_states.device
)
hidden_states = self.layer[0](
hidden_states=hidden_states,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
if self.is_decoder:
hidden_states = self.layer[1](
hidden_states=hidden_states, attn_metadata=attn_metadata
)
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
return hidden_states
class T5Stack(nn.Module):
def __init__(
self,
config: T5Config,
is_decoder: bool,
n_layers: int,
embed_tokens=None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
is_umt5: bool = False,
):
super().__init__()
self.embed_tokens = embed_tokens
self.is_umt5 = is_umt5
if is_umt5:
self.block = nn.ModuleList(
[
T5Block(
config,
is_decoder=is_decoder,
has_relative_attention_bias=True,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}",
)
for i in range(n_layers)
]
)
else:
# Only the first block has relative positional encoding.
self.block = nn.ModuleList(
[
T5Block(
config,
is_decoder=is_decoder,
has_relative_attention_bias=i == 0,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}",
)
for i in range(n_layers)
]
)
self.final_layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
        for block in self.block:
hidden_states = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class T5EncoderModel(TextEncoder):
def __init__(self, config: T5Config, prefix: str = ""):
super().__init__(config)
quant_config = None
self.shared = VocabParallelEmbedding(
config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size
)
self.encoder = T5Stack(
config,
False,
config.num_layers,
self.shared,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
is_umt5=False,
)
def get_input_embeddings(self):
return self.shared
def forward(
self,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
**kwargs,
) -> BaseEncoderOutput:
attn_metadata = AttentionMetadata(None)
hidden_states = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
return BaseEncoderOutput(last_hidden_state=hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q", "q"),
(".qkv_proj", ".k", "k"),
(".qkv_proj", ".v", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
loaded = False
if "decoder" in name or "lm_head" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded = True
break
if not loaded:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
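# Hedged usage sketch (not part of the original file; variable names and the surrounding
# tokenization are assumptions for illustration):
#     model = T5EncoderModel(config)  # config: a T5Config-compatible object
#     out = model(input_ids=input_ids, attention_mask=attention_mask)
#     prompt_embeds = out.last_hidden_state  # (batch, seq_len, d_model)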
class UMT5EncoderModel(TextEncoder):
def __init__(self, config: T5Config, prefix: str = ""):
super().__init__(config)
quant_config = None
self.shared = VocabParallelEmbedding(
config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size
)
self.encoder = T5Stack(
config,
False,
config.num_layers,
self.shared,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
is_umt5=True,
)
def get_input_embeddings(self):
return self.shared
def forward(
self,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
output_hidden_states: bool | None = None,
**kwargs,
) -> BaseEncoderOutput:
attn_metadata = AttentionMetadata(None)
hidden_states = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
attn_metadata=attn_metadata,
)
return BaseEncoderOutput(
last_hidden_state=hidden_states,
attention_mask=attention_mask,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
loaded = False
if "decoder" in name or "lm_head" in name:
continue
for (
param_name,
weight_name,
shard_id,
) in self.config.arch_config.stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded = True
break
if not loaded:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
EntryClass = [T5EncoderModel, UMT5EncoderModel]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/vision.py
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
import torch
from transformers import PretrainedConfig
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig)
class VisionEncoderInfo(ABC, Generic[_C]):
def __init__(self, vision_config: _C) -> None:
super().__init__()
self.vision_config = vision_config
@abstractmethod
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
raise NotImplementedError
@abstractmethod
def get_max_image_tokens(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_size(self) -> int:
raise NotImplementedError
@abstractmethod
def get_patch_grid_length(self) -> int:
raise NotImplementedError
def resolve_visual_encoder_outputs(
encoder_outputs: torch.Tensor | list[torch.Tensor],
feature_sample_layers: list[int] | None,
post_layer_norm: torch.nn.LayerNorm | None,
max_possible_layers: int,
) -> torch.Tensor:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if feature_sample_layers is None:
if post_layer_norm is not None:
return post_layer_norm(encoder_outputs)
return encoder_outputs
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs is a list containing
# the inputs to the visual encoder, followed by the hidden states
# of each layer.
num_loaded_layers = len(encoder_outputs) - 1
offset = max_possible_layers - num_loaded_layers
hs_pool = [
(
encoder_outputs[layer_idx]
if layer_idx >= 0
else encoder_outputs[layer_idx + offset]
)
for layer_idx in feature_sample_layers
]
# Apply post-norm on the final hidden state if we are using it
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
        hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
return torch.cat(hs_pool, dim=-1)
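# Worked example (illustrative): for a fully loaded 24-layer visual encoder,
# encoder_outputs is a list of 25 tensors (the embeddings plus one hidden state per layer).
# With feature_sample_layers=[-2, -1] the offset is 24 - 24 = 0, so the pooled features are
# [encoder_outputs[-2], encoder_outputs[-1]] concatenated along the feature dimension, with
# post_layer_norm (if provided) applied to the last of them.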
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/parameter.py
from collections.abc import Callable
from fractions import Fraction
from typing import Any
import torch
from torch.nn import Parameter
from sglang.multimodal_gen.runtime.distributed import get_tp_rank
from sglang.multimodal_gen.runtime.models.utils import _make_synced_weight_loader
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class BasevLLMParameter(Parameter):
"""
Base parameter for vLLM linear layers. Extends the torch.nn.parameter
by taking in a linear weight loader. Will copy the loaded weight
into the parameter when the provided weight loader is called.
"""
def __new__(cls, data: torch.Tensor, **kwargs):
return super().__new__(cls, data=data, requires_grad=False)
def __init__(self, data: torch.Tensor, weight_loader: Callable):
"""
Initialize the BasevLLMParameter
:param data: torch tensor with the parameter data
:param weight_loader: weight loader callable
:returns: a torch.nn.parameter
"""
# During weight loading, we often do something like:
# narrowed_tensor = param.data.narrow(0, offset, len)
# narrowed_tensor.copy_(real_weight)
# expecting narrowed_tensor and param.data to share the same storage.
# However, on TPUs, narrowed_tensor will lazily propagate to the base
# tensor, which is param.data, leading to the redundant memory usage.
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
from sglang.multimodal_gen.runtime.platforms import current_platform
if current_platform.is_tpu():
weight_loader = _make_synced_weight_loader(weight_loader)
self._weight_loader = weight_loader
@property
def weight_loader(self):
return self._weight_loader
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
cond1 = self.data.ndim == 1 and self.data.numel() == 1
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
return cond1 and cond2
def _assert_and_load(self, loaded_weight: torch.Tensor) -> None:
assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar(
loaded_weight
)
self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
self._assert_and_load(loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
self._assert_and_load(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
self._assert_and_load(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
self._assert_and_load(loaded_weight)
class _ColumnvLLMParameter(BasevLLMParameter):
"""
Private class defining weight loading functionality
(load_merged_column_weight, load_qkv_weight)
for parameters being loaded into linear layers with column
parallelism. This includes QKV and MLP layers which are
not already fused on disk. Requires an output dimension
to be defined. Called within the weight loader of
each of the column parallel linear layers.
"""
def __init__(self, output_dim: int, **kwargs):
self._output_dim = output_dim
super().__init__(**kwargs)
@property
def output_dim(self):
return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tp_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
if shard_offset is None or shard_size is None:
raise ValueError("shard_offset and shard_size must be provided")
if (
isinstance(self, PackedColumnParameter | PackedvLLMParameter)
and self.packed_dim == self.output_dim
):
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size
)
param_data = self.data
tp_rank = get_tp_rank()
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
self.output_dim, tp_rank * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
shard_id = kwargs.get("shard_id")
num_heads = kwargs.get("num_heads")
assert shard_offset is not None
assert shard_size is not None
assert shard_id is not None
assert num_heads is not None
if (
isinstance(self, PackedColumnParameter | PackedvLLMParameter)
and self.output_dim == self.packed_dim
):
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size
)
param_data = self.data
tp_rank = get_tp_rank()
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowvLLMParameter(BasevLLMParameter):
"""
Parameter class defining weight_loading functionality
(load_row_parallel_weight) for parameters being loaded
into linear layers with row parallel functionality.
Requires an input_dim to be defined.
"""
def __init__(self, input_dim: int, **kwargs):
self._input_dim = input_dim
super().__init__(**kwargs)
@property
def input_dim(self):
return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
tp_rank = get_tp_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(
self.input_dim, tp_rank * shard_size, shard_size
)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for linear layer weights. Uses both column and
row parallelism.
"""
pass
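# Hedged sharding sketch (numbers are illustrative): with tensor-parallel world size 2 and a
# checkpoint weight of shape (8, 4), each rank holds a ModelWeightParameter of shape (4, 4)
# with output_dim=0 and input_dim=1; load_column_parallel_weight then narrows the loaded
# tensor to rows [0:4) on rank 0 and rows [4:8) on rank 1 before copying it in.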
class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
grouped quantization. Uses both column and row parallelism.
"""
pass
class ChannelQuantScaleParameter(_ColumnvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
channel-wise quantization. Equivalent to _ColumnvLLMParameter.
"""
pass
class PerTensorScaleParameter(BasevLLMParameter):
"""
Parameter class for scales where the number of scales is
equivalent to the number of logical matrices in fused linear
layers (e.g. for QKV, there are 3 scales loaded from disk).
This is relevant to weights with per-tensor quantization.
Adds functionality to map the scalers to a shard during
weight loading.
Note: additional parameter manipulation may be handled
for each quantization config specifically, within
process_weights_after_loading
"""
def __init__(self, **kwargs):
self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
super().__init__(**kwargs)
def _shard_id_as_int(self, shard_id: str | int) -> int:
if isinstance(shard_id, int):
return shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert isinstance(shard_id, str)
assert shard_id in self.qkv_idxs
return self.qkv_idxs[shard_id]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def load_row_parallel_weight(self, *args, **kwargs) -> None:
super().load_row_parallel_weight(*args, **kwargs)
def load_merged_column_weight(self, *args, **kwargs) -> None:
self._load_into_shard_id(*args, **kwargs)
def load_qkv_weight(self, *args, **kwargs) -> None:
self._load_into_shard_id(*args, **kwargs)
def load_column_parallel_weight(self, *args, **kwargs) -> None:
super().load_row_parallel_weight(*args, **kwargs)
def _load_into_shard_id(
self, loaded_weight: torch.Tensor, shard_id: str | int, **kwargs
):
"""
Slice the parameter data based on the shard id for
loading.
"""
param_data = self.data
shard_id = self._shard_id_as_int(shard_id)
# AutoFP8 scales do not have a shape
# compressed-tensors scales do have a shape
if len(loaded_weight.shape) != 0:
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]
param_data = param_data[shard_id]
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
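# Illustrative example (hedged): for a fused QKV layer quantized per-tensor, the checkpoint
# stores three scalar scales; a PerTensorScaleParameter of shape (3,) maps them via qkv_idxs,
# so load_qkv_weight(scale, shard_id="k") copies the loaded scalar into param.data[1].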
class PackedColumnParameter(_ColumnvLLMParameter):
"""
Parameter for model parameters which are packed on disk
and support column parallelism only. See PackedvLLMParameter
for more details on the packed properties.
"""
def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
super().__init__(**kwargs)
@property
def packed_dim(self):
return self._packed_dim
@property
def packed_factor(self):
return self._packed_factor
def adjust_shard_indexes_for_packing(
self, shard_size, shard_offset
) -> tuple[Any, Any]:
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor,
)
class PackedvLLMParameter(ModelWeightParameter):
"""
Parameter for model weights which are packed on disk.
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
Extends the ModelWeightParameter to take in the
packed factor, the packed dimension, and optionally, marlin
tile size for marlin kernels. Adjusts the shard_size and
shard_offset for fused linear layers model weight loading
by accounting for packing and optionally, marlin tile size.
"""
def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
super().__init__(**kwargs)
@property
def packed_dim(self):
return self._packed_dim
@property
def packed_factor(self):
return self._packed_factor
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor,
)
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
block-wise quantization. Uses both column and row parallelism.
"""
pass
def permute_param_layout_(
param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs
) -> BasevLLMParameter:
"""
    Permute a parameter's layout to the specified input and output dimensions.
    This is useful for forcing the parameter into a known layout; for example,
    if a packed (quantized) weight matrix needs to be in the layout
    {input_dim = 0, output_dim = 1, packed_dim = 0}
    then calling:
    permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    ensures x is in the correct layout (permuting it if required, and
    asserting if it cannot be brought into that layout)
"""
curr_input_dim = getattr(param, "input_dim", None)
curr_output_dim = getattr(param, "output_dim", None)
if curr_input_dim is None or curr_output_dim is None:
assert param.data.dim() == 2, (
"permute_param_layout_ only supports 2D parameters when either "
"input_dim or output_dim is not set"
)
# if one of the dimensions is not set, set it to the opposite of the other
# we can only do this since we asserted the parameter is 2D above
if curr_input_dim is None:
assert curr_output_dim is not None, "either input or output dim must be set"
curr_input_dim = (curr_output_dim + 1) % 2
if curr_output_dim is None:
assert curr_input_dim is not None, "either input or output dim must be set"
curr_output_dim = (curr_input_dim + 1) % 2
# create permutation from the current layout to the layout with
# self.input_dim at input_dim and self.output_dim at output_dim preserving
# other dimensions
perm = [
i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim]
]
perm.insert(input_dim, curr_input_dim)
perm.insert(output_dim, curr_output_dim)
if "packed_dim" in kwargs:
assert (
hasattr(param, "packed_dim")
and param.packed_dim == perm[kwargs["packed_dim"]]
), "permute_param_layout_ currently doesn't support repacking"
param.data = param.data.permute(*perm)
if hasattr(param, "_input_dim"):
param._input_dim = input_dim
if hasattr(param, "_output_dim"):
param._output_dim = output_dim
if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
param._packed_dim = kwargs["packed_dim"]
return param
def _adjust_shard_indexes_for_packing(
shard_size, shard_offset, packed_factor
) -> tuple[Any, Any]:
shard_size = shard_size // packed_factor
shard_offset = shard_offset // packed_factor
return shard_size, shard_offset
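# Worked example (illustrative): int4 weights packed into int32 give packed_factor=8, so a
# logical shard of size 4096 at offset 4096 becomes a packed shard of size 512 at offset 512.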
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/registry.py
import ast
import importlib
import os
import pickle
import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Callable, Set
from dataclasses import dataclass, field
from functools import lru_cache
from typing import NoReturn, TypeVar, cast
import cloudpickle
from torch import nn
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
MODELS_PATH = os.path.dirname(__file__)
COMPONENT_DIRS = [
d
for d in os.listdir(MODELS_PATH)
if os.path.isdir(os.path.join(MODELS_PATH, d))
and not d.startswith("__")
and not d.startswith(".")
]
_IMAGE_ENCODER_MODELS: dict[str, tuple] = {
# "HunyuanVideoTransformer3DModel": ("image_encoder", "hunyuanvideo", "HunyuanVideoImageEncoder"),
"CLIPVisionModelWithProjection": ("encoders", "clip", "CLIPVisionModel"),
}
@lru_cache(maxsize=None)
def _discover_and_register_models() -> dict[str, tuple[str, str, str]]:
    # Copy so the module-level table itself is not mutated during discovery.
    discovered_models = dict(_IMAGE_ENCODER_MODELS)
for component in COMPONENT_DIRS:
component_path = os.path.join(MODELS_PATH, component)
for filename in os.listdir(component_path):
if not filename.endswith(".py"):
continue
mod_relname = filename[:-3]
filepath = os.path.join(component_path, filename)
try:
with open(filepath, "r", encoding="utf-8") as f:
source = f.read()
tree = ast.parse(source, filename=filename)
entry_class_node = None
first_class_def = None
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if (
isinstance(target, ast.Name)
and target.id == "EntryClass"
):
entry_class_node = node
break
if first_class_def is None and isinstance(node, ast.ClassDef):
first_class_def = node
if entry_class_node and first_class_def:
model_cls_name_list = []
value_node = entry_class_node.value
# EntryClass = ClassName
if isinstance(value_node, ast.Name):
model_cls_name_list.append(value_node.id)
# EntryClass = ["...", ClassName, ...]
elif isinstance(value_node, (ast.List, ast.Tuple)):
for elt in value_node.elts:
if isinstance(elt, ast.Constant):
model_cls_name_list.append(elt.value)
elif isinstance(elt, ast.Name):
model_cls_name_list.append(elt.id)
if model_cls_name_list:
for model_cls_str in model_cls_name_list:
if model_cls_str in discovered_models:
logger.warning(
f"Duplicate architecture found: {model_cls_str}. It will be overwritten."
)
model_arch = model_cls_str
discovered_models[model_arch] = (
component,
mod_relname,
model_cls_str,
)
except Exception as e:
logger.warning(f"Could not parse {filepath} to find models: {e}")
return discovered_models
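# Illustrative example (the file and class names are hypothetical): a module such as
# models/dits/my_dit.py that defines `class MyDiT(nn.Module): ...` and `EntryClass = MyDiT`
# (or `EntryClass = [MyDiT]`) is picked up by the AST scan above and recorded as
# {"MyDiT": ("dits", "my_dit", "MyDiT")} without ever importing the module.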
_SGL_DIFFUSION_MODELS = _discover_and_register_models()
_SUBPROCESS_COMMAND = [
sys.executable,
"-m",
"sglang.multimodal_gen.runtime.models.dits.registry",
]
_T = TypeVar("_T")
@dataclass(frozen=True)
class _ModelInfo:
architecture: str
@staticmethod
def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
return _ModelInfo(
architecture=model.__name__,
)
class _BaseRegisteredModel(ABC):
@abstractmethod
def inspect_model_cls(self) -> _ModelInfo:
raise NotImplementedError
@abstractmethod
def load_model_cls(self) -> type[nn.Module]:
raise NotImplementedError
@dataclass(frozen=True)
class _RegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has already been imported in the main process.
"""
interfaces: _ModelInfo
model_cls: type[nn.Module]
@staticmethod
def from_model_cls(model_cls: type[nn.Module]):
return _RegisteredModel(
interfaces=_ModelInfo.from_model_cls(model_cls),
model_cls=model_cls,
)
def inspect_model_cls(self) -> _ModelInfo:
return self.interfaces
def load_model_cls(self) -> type[nn.Module]:
return self.model_cls
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
# NOTE: We use a temporary directory instead of a temporary file to avoid
# issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
with tempfile.TemporaryDirectory() as tempdir:
output_filepath = os.path.join(tempdir, "registry_output.tmp")
# `cloudpickle` allows pickling lambda functions directly
input_bytes = cloudpickle.dumps((fn, output_filepath))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned = subprocess.run(
_SUBPROCESS_COMMAND, input=input_bytes, capture_output=True
)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error raised in subprocess:\n" f"{returned.stderr.decode()}"
) from e
with open(output_filepath, "rb") as f:
return cast(_T, pickle.load(f))
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
"""
Represents a model that has not been imported in the main process.
"""
module_name: str
component_name: str
class_name: str
# Performed in another process to avoid initializing CUDA
def inspect_model_cls(self) -> _ModelInfo:
return _run_in_subprocess(
lambda: _ModelInfo.from_model_cls(self.load_model_cls())
)
def load_model_cls(self) -> type[nn.Module]:
mod = importlib.import_module(self.module_name)
return cast(type[nn.Module], getattr(mod, self.class_name))
@lru_cache(maxsize=128)
def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> type[nn.Module] | None:
from sglang.multimodal_gen.runtime.platforms import current_platform
current_platform.verify_model_arch(model_arch)
try:
return model.load_model_cls()
except Exception:
logger.exception("Ignore import error when loading '%s'", model_arch)
return None
@lru_cache(maxsize=128)
def _try_inspect_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> _ModelInfo | None:
try:
return model.inspect_model_cls()
except Exception:
logger.exception("Error in inspecting model architecture '%s'", model_arch)
return None
@dataclass
class _ModelRegistry:
# Keyed by model_arch
models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
def get_supported_archs(self) -> Set[str]:
return self.models.keys()
def register_model(
self,
model_arch: str,
model_cls: type[nn.Module] | str,
) -> None:
"""
Register an external model to be used in vLLM.
:code:`model_cls` can be either:
- A :class:`torch.nn.Module` class directly referencing the model.
- A string in the format :code:`<module>:<class>` which can be used to
lazily import the model. This is useful to avoid initializing CUDA
when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if model_arch in self.models:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.",
model_arch,
model_cls,
)
if isinstance(model_cls, str):
split_str = model_cls.split(":")
if len(split_str) != 2:
msg = "Expected a string in the format `<module>:<class>`"
raise ValueError(msg)
model = _LazyRegisteredModel(*split_str)
else:
model = _RegisteredModel.from_model_cls(model_cls)
self.models[model_arch] = model
def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn:
all_supported_archs = self.get_supported_archs()
if any(arch in all_supported_archs for arch in architectures):
raise ValueError(
f"Model architectures {architectures} failed "
"to be inspected. Please check the logs for more details."
)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {all_supported_archs}"
)
def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None:
if model_arch not in self.models:
return None
return _try_load_model_cls(model_arch, self.models[model_arch])
def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None:
if model_arch not in self.models:
return None
return _try_inspect_model_cls(model_arch, self.models[model_arch])
def _normalize_archs(
self,
architectures: str | list[str],
) -> list[str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")
normalized_arch = []
        for model in architectures:
            if model not in self.models:
                raise Exception(
                    f"Unsupported model architecture: {model}. "
                    f"Registered architectures: {list(self.models.keys())}"
                )
            normalized_arch.append(model)
return normalized_arch
def inspect_model_cls(
self,
architectures: str | list[str],
) -> tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return (model_info, arch)
return self._raise_for_unsupported(architectures)
def resolve_model_cls(
self,
architectures: str | list[str],
) -> tuple[type[nn.Module], str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
return self._raise_for_unsupported(architectures)
ModelRegistry = _ModelRegistry(
{
model_arch: _LazyRegisteredModel(
module_name=f"sglang.multimodal_gen.runtime.models.{component_name}.{mod_relname}",
component_name=component_name,
class_name=cls_name,
)
for model_arch, (
component_name,
mod_relname,
cls_name,
) in _SGL_DIFFUSION_MODELS.items()
}
)
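# Hedged usage sketch (the architecture names are illustrative):
#     model_cls, arch = ModelRegistry.resolve_model_cls("CLIPVisionModelWithProjection")
# resolves a lazily registered class without importing every model module up front, and
#     ModelRegistry.register_model("MyDiT", "my_pkg.my_dit:MyDiT")
# registers an external model by the `<module>:<class>` string form described above.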
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
import torch
class BaseScheduler(ABC):
timesteps: torch.Tensor
order: int
num_train_timesteps: int
def __init__(self, *args, **kwargs) -> None:
# Check if subclass has defined all required properties
required_attributes = ["timesteps", "order", "num_train_timesteps"]
for attr in required_attributes:
if not hasattr(self, attr):
raise AttributeError(
f"Subclasses of BaseScheduler must define '{attr}' property"
)
@abstractmethod
def set_shift(self, shift: float) -> None:
pass
@abstractmethod
def set_timesteps(self, *args, **kwargs) -> None:
pass
@abstractmethod
def scale_model_input(
self, sample: torch.Tensor, timestep: int | None = None
) -> torch.Tensor:
pass
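# Minimal subclass sketch (illustrative only), showing the attributes the __init__ check
# above expects to be set before BaseScheduler.__init__ runs:
#     class MyScheduler(BaseScheduler):
#         def __init__(self):
#             self.timesteps = torch.linspace(1000, 1, 1000)
#             self.order = 1
#             self.num_train_timesteps = 1000
#             super().__init__()
#         def set_shift(self, shift: float) -> None: ...
#         def set_timesteps(self, *args, **kwargs) -> None: ...
#         def scale_model_input(self, sample, timestep=None): return sample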
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
import math
from dataclasses import dataclass
from typing import Any
import numpy as np
import scipy.stats
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
use_dynamic_shifting (`bool`, defaults to False):
Whether to apply timestep shifting on-the-fly based on the image resolution.
        base_shift (`float`, defaults to 0.5):
            Value to stabilize image generation. Increasing `base_shift` reduces variation, and the image is more
            consistent with the desired output.
        max_shift (`float`, defaults to 1.15):
            Maximum change allowed to latent vectors. Increasing `max_shift` encourages more variation, and the image
            may be more exaggerated or stylized.
base_image_seq_len (`int`, defaults to 256):
The base image sequence length.
max_image_seq_len (`int`, defaults to 4096):
The maximum image sequence length.
invert_sigmas (`bool`, defaults to False):
Whether to invert the sigmas.
shift_terminal (`float`, defaults to None):
The end value of the shifted timestep schedule.
use_karras_sigmas (`bool`, defaults to False):
Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
use_exponential_sigmas (`bool`, defaults to False):
Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
use_beta_sigmas (`bool`, defaults to False):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
stochastic_sampling (`bool`, defaults to False):
Whether to use stochastic sampling.
"""
_compatibles: list[Any] = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
use_dynamic_shifting: bool = False,
base_shift: float | None = 0.5,
max_shift: float | None = 1.15,
base_image_seq_len: int | None = 256,
max_image_seq_len: int | None = 4096,
invert_sigmas: bool = False,
shift_terminal: float | None = None,
use_karras_sigmas: bool | None = False,
use_exponential_sigmas: bool | None = False,
use_beta_sigmas: bool | None = False,
time_shift_type: str = "exponential",
stochastic_sampling: bool = False,
):
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if time_shift_type not in {"exponential", "linear"}:
raise ValueError(
"`time_shift_type` must either be 'exponential' or 'linear'."
)
timesteps = np.linspace(
1, num_train_timesteps, num_train_timesteps, dtype=np.float32
)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self.num_train_timesteps = num_train_timesteps
self._step_index: int | None = None
self._begin_index: int | None = None
self._shift = shift
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
BaseScheduler.__init__(self)
@property
def shift(self) -> float:
"""
The value used for shifting.
"""
return self._shift
@property
def step_index(self) -> int | None:
"""
        The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self) -> int | None:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_shift(self, shift: float) -> None:
self._shift = shift
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: float | torch.FloatTensor,
noise: torch.FloatTensor | None = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
if sample.device.type == "mps" and torch.is_floating_point(timestep):
# mps does not support float64
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
assert isinstance(timestep, torch.Tensor)
timestep = timestep.to(sample.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(sample.device)
assert isinstance(timestep, torch.Tensor)
timestep = timestep.to(sample.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps) for t in timestep
]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timestep.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(sample.shape):
sigma = sigma.unsqueeze(-1)
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma: float) -> float:
return sigma * self.config.num_train_timesteps
def time_shift(
self, mu: float, sigma: float, t: torch.Tensor | np.ndarray
) -> torch.Tensor | np.ndarray:
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
return self._time_shift_linear(mu, sigma, t)
else:
raise ValueError(f"Unknown time_shift_type: {self.config.time_shift_type}")
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
r"""
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
value.
Reference:
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
Args:
t (`torch.Tensor`):
A tensor of timesteps to be stretched and shifted.
Returns:
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
stretched_t = 1 - (one_minus_z / scale_factor)
return stretched_t
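    # Worked example (illustrative numbers): if the shifted schedule ends at t[-1] = 0.1 and
    # shift_terminal = 0.2, then scale_factor = 0.9 / 0.8 = 1.125 and the final value becomes
    # 1 - 0.9 / 1.125 = 0.2, i.e. the schedule terminates exactly at `shift_terminal`.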
def set_timesteps(
self,
num_inference_steps: int | None = None,
        device: str | torch.device | None = None,
sigmas: list[float] | None = None,
mu: float | None = None,
timesteps: list[float] | None = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
sigmas (`List[float]`, *optional*):
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
automatically.
mu (`float`, *optional*):
Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
shifting.
timesteps (`List[float]`, *optional*):
Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
automatically.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(
"`mu` must be passed when `use_dynamic_shifting` is set to be `True`"
)
if (
sigmas is not None
and timesteps is not None
and len(sigmas) != len(timesteps)
):
raise ValueError("`sigmas` and `timesteps` should have the same length")
if num_inference_steps is not None:
if (sigmas is not None and len(sigmas) != num_inference_steps) or (
timesteps is not None and len(timesteps) != num_inference_steps
):
raise ValueError(
"`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
)
else:
if sigmas is not None:
num_inference_steps = len(sigmas)
elif timesteps is not None:
num_inference_steps = len(timesteps)
else:
raise ValueError(
"Either num_inference_steps, sigmas, or timesteps must be provided"
)
self.num_inference_steps = num_inference_steps
# 1. Prepare default sigmas
is_timesteps_provided = timesteps is not None
timesteps_array: np.ndarray | None = None
if is_timesteps_provided:
assert timesteps is not None
timesteps_array = np.array(timesteps).astype(np.float32)
sigmas_array: np.ndarray
if sigmas is None:
if timesteps_array is None:
timesteps_array = np.linspace(
self._sigma_to_t(self.sigma_max),
self._sigma_to_t(self.sigma_min),
num_inference_steps,
)
sigmas_array = timesteps_array / self.config.num_train_timesteps
else:
sigmas_array = np.array(sigmas).astype(np.float32)
num_inference_steps = len(sigmas_array)
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
# "exponential" or "linear" type is applied
if self.config.use_dynamic_shifting:
assert mu is not None, "mu cannot be None when use_dynamic_shifting is True"
sigmas_array = self.time_shift(mu, 1.0, sigmas_array)
else:
sigmas_array = (
self.shift * sigmas_array / (1 + (self.shift - 1) * sigmas_array)
)
# 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
if self.config.shift_terminal:
sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)
sigmas_tensor = self.stretch_shift_to_terminal(sigmas_tensor)
sigmas_array = sigmas_tensor.numpy()
# 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
if self.config.use_karras_sigmas:
sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)
sigmas_tensor = self._convert_to_karras(
in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps
)
sigmas_array = sigmas_tensor.numpy()
elif self.config.use_exponential_sigmas:
sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)
sigmas_tensor = self._convert_to_exponential(
in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps
)
sigmas_array = sigmas_tensor.numpy()
elif self.config.use_beta_sigmas:
sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32)
sigmas_tensor = self._convert_to_beta(
in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps
)
sigmas_array = sigmas_tensor.numpy()
# 5. Convert sigmas and timesteps to tensors and move to specified device
sigmas_tensor = torch.from_numpy(sigmas_array).to(
dtype=torch.float32, device=device
)
if not is_timesteps_provided:
timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps
else:
assert timesteps_array is not None
timesteps_tensor = torch.from_numpy(timesteps_array).to(
dtype=torch.float32, device=device
)
# 6. Append the terminal sigma value.
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
# `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
if self.config.invert_sigmas:
sigmas_tensor = 1.0 - sigmas_tensor
timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps
sigmas_tensor = torch.cat(
[sigmas_tensor, torch.ones(1, device=sigmas_tensor.device)]
)
else:
sigmas_tensor = torch.cat(
[sigmas_tensor, torch.zeros(1, device=sigmas_tensor.device)]
)
self.timesteps = timesteps_tensor
self.sigmas = sigmas_tensor
self._step_index = None
self._begin_index = None
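    # Shift sketch (illustrative): with a static shift of 3.0, a mid-schedule sigma of 0.5
    # becomes 3 * 0.5 / (1 + (3 - 1) * 0.5) = 0.75, pushing more of the sampling budget toward
    # high noise levels; with use_dynamic_shifting=True the same warping is driven by `mu`.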
def index_for_timestep(
self,
timestep: float | torch.FloatTensor,
schedule_timesteps: torch.Tensor | None = None,
) -> int:
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep: float | torch.FloatTensor) -> None:
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: int | torch.Tensor,
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: torch.Generator | None = None,
per_token_timesteps: torch.Tensor | None = None,
return_dict: bool = True,
) -> FlowMatchEulerDiscreteSchedulerOutput | tuple[torch.FloatTensor, ...]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
per_token_timesteps (`torch.Tensor`, *optional*):
The timesteps for each token in the sample.
return_dict (`bool`):
Whether or not to return a
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int | torch.IntTensor | torch.LongTensor):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
if per_token_timesteps is not None:
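            # Per-token stepping: each token carries its own current sigma.
            # For every token, pick the largest schedule sigma strictly below
            # it as the "next" sigma and use dt = current_sigma - next_sigma.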
per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
sigmas = self.sigmas[:, None, None]
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
lower_sigmas = lower_mask * sigmas
lower_sigmas, _ = lower_sigmas.max(dim=0)
current_sigma = per_token_sigmas[..., None]
next_sigma = lower_sigmas[..., None]
dt = current_sigma - next_sigma
else:
assert self.step_index is not None, "step_index should not be None"
sigma_idx = self.step_index
sigma = self.sigmas[sigma_idx]
sigma_next = self.sigmas[sigma_idx + 1]
current_sigma = sigma
next_sigma = sigma_next
dt = sigma_next - sigma
if self.config.stochastic_sampling:
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
else:
prev_sample = sample + dt * model_output
# upon completion increase step index by one
assert self._step_index is not None, "_step_index should not be None"
self._step_index += 1
if per_token_timesteps is None:
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
if isinstance(prev_sample, torch.Tensor | float) and not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(
self, in_sigmas: torch.Tensor, num_inference_steps: int
) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(
self, in_sigmas: torch.Tensor, num_inference_steps: int
) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.exp(
np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)
)
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self,
in_sigmas: torch.Tensor,
num_inference_steps: int,
alpha: float = 0.6,
beta: float = 0.6,
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def _time_shift_exponential(
self, mu: float, sigma: float, t: torch.Tensor | np.ndarray
) -> torch.Tensor | np.ndarray:
if isinstance(t, np.ndarray):
return np.exp(mu) / (np.exp(mu) + (1 / t - 1) ** sigma)
else:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def _time_shift_linear(
self, mu: float, sigma: float, t: torch.Tensor | np.ndarray
) -> torch.Tensor | np.ndarray:
return mu / (mu + (1 / t - 1) ** sigma)
def add_noise(
self,
clean_latent: torch.Tensor,
noise: torch.Tensor,
timestep: torch.IntTensor,
) -> torch.Tensor:
"""
Args:
clean_latent: the clean latent with shape [B, C, H, W],
where B is batch_size or batch_size * num_frames
noise: the noise with shape [B, C, H, W]
timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames]
Returns:
the corrupted latent with shape [B, C, H, W]
"""
# If timestep is [bs, num_frames]
if timestep.ndim == 2:
timestep = timestep.flatten(0, 1)
assert timestep.numel() == clean_latent.shape[0]
elif timestep.ndim == 1:
# If timestep is [1]
if timestep.shape[0] == 1:
timestep = timestep.expand(clean_latent.shape[0])
else:
assert timestep.numel() == clean_latent.shape[0]
else:
raise ValueError(f"[add_noise] Invalid timestep shape: {timestep.shape}")
# timestep shape should be [B]
self.sigmas = self.sigmas.to(noise.device)
self.timesteps = self.timesteps.to(noise.device)
timestep_id = torch.argmin(
(self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
)
sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
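        # Flow-matching forward process: x_t = (1 - sigma) * x_0 + sigma * noise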
sample = (1 - sigma) * clean_latent + sigma * noise
return sample.type_as(noise)
def scale_model_input(
self, sample: torch.Tensor, timestep: int | None = None
) -> torch.Tensor:
return sample
def __len__(self) -> int:
return 0
EntryClass = FlowMatchEulerDiscreteScheduler
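# --- Illustrative sketch (not part of the original file) -------------------
# A minimal, self-contained demonstration of the two update rules implemented
# by `FlowMatchEulerDiscreteScheduler.step` above, written against plain
# tensors so the math can be read without constructing the scheduler. The
# sigma values and tensor shapes below are made up for illustration only.
def _flow_match_euler_update_demo() -> None:
    import torch

    sample = torch.randn(1, 4, 8, 8)      # current latent x_t
    velocity = torch.randn_like(sample)   # model output (predicted flow)
    sigma, sigma_next = 0.75, 0.50        # two adjacent schedule sigmas

    # Deterministic Euler step: x_prev = x_t + (sigma_next - sigma) * v
    prev_deterministic = sample + (sigma_next - sigma) * velocity

    # Stochastic variant (`config.stochastic_sampling=True`): re-noise the
    # predicted clean sample x0 = x_t - sigma * v at the next noise level.
    x0 = sample - sigma * velocity
    noise = torch.randn_like(sample)
    prev_stochastic = (1.0 - sigma_next) * x0 + sigma_next * noise

    assert prev_deterministic.shape == prev_stochastic.shape == sample.shape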
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
# Convert unipc for flow matching
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import Any
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate
from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
"""
    `FlowUniPCMultistepScheduler` is a flow-matching adaptation of UniPC, a training-free framework designed for the fast sampling of diffusion models.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, default `2`):
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
unconditional sampling.
prediction_type (`str`, defaults to "flow_prediction"):
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
the flow of the diffusion process.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, defaults to `True`):
Whether to use the updating algorithm on the predicted x0.
solver_type (`str`, default `bh2`):
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
disable_corrector (`list`, default `[]`):
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
usually disabled during the first few steps.
solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler; if specified, the algorithm becomes `solver_p + UniC`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "flow_prediction",
shift: float | None = 1.0,
use_dynamic_shifting=False,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
lower_order_final: bool = True,
disable_corrector: tuple = (),
solver_p: SchedulerMixin = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: str | None = "zero", # "zero", "sigma_min"
**kwargs,
):
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
self.register_to_config(solver_type="bh2")
else:
raise NotImplementedError(
f"{solver_type} is not implemented for {self.__class__}"
)
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps: int | None = None
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[
::-1
].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
assert shift is not None
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
self.sigmas = sigmas
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.timesteps = sigmas * num_train_timesteps
self.num_train_timesteps = num_train_timesteps
self.model_outputs = [None] * solver_order
self.timestep_list: list[Any | None] = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = list(disable_corrector)
self.solver_p = solver_p
self.last_sample = None
self._step_index: int | None = None
self._begin_index: int | None = None
BaseScheduler.__init__(self)
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
def set_shift(self, shift: float) -> None:
self.config.shift = shift
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
def set_timesteps(
self,
num_inference_steps: int | None = None,
device: str | torch.device = None,
sigmas: list[float] | None = None,
        mu: float | None = None,
        shift: float | None = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(
" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
)
if sigmas is None:
assert num_inference_steps is not None
sigmas = np.linspace(
self.sigma_max, self.sigma_min, num_inference_steps + 1
            ).copy()[:-1]  # pyright: ignore
if self.config.use_dynamic_shifting:
assert mu is not None
sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
else:
if shift is None:
shift = self.config.shift
assert isinstance(sigmas, np.ndarray)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
if self.config.final_sigmas_type == "sigma_min":
            # This flow-matching scheduler has no `alphas_cumprod`; the smallest training sigma is the terminal value.
            sigma_last = self.sigma_min
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
timesteps = sigmas * self.config.num_train_timesteps
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(
np.float32
) # pyright: ignore
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(
device=device, dtype=torch.int64
)
self.num_inference_steps = len(timesteps)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = (
sample.float()
) # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = (
torch.clamp(sample, -s, s) / s
) # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def _sigma_to_alpha_sigma_t(self, sigma) -> tuple[Any, Any]:
return 1 - sigma, sigma
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""
Convert the model output to the corresponding type the UniPC algorithm needs.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
epsilon = sample - (1 - sigma_t) * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
" `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
x0_pred = self._threshold_sample(x0_pred)
epsilon = model_output + x0_pred
return epsilon
def multistep_uni_p_bh_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
order: int | None = None, # pyright: ignore
**kwargs,
) -> torch.Tensor:
"""
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
order (`int`):
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError(" missing `sample` as a required keyword argument")
if order is None:
if len(args) > 2:
order = args[2]
else:
raise ValueError(" missing `order` as a required keyword argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
s0 = self.timestep_list[-1]
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
sigma_t, sigma_s0 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
) # pyright: ignore
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
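        # lambda = log(alpha) - log(sigma) is half the log-SNR; h is its
        # increment over the current step.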
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = sample.device
rks = []
D1s: list[Any] | None = []
sigmas = self.sigmas.to(device=device)
for i in range(1, order):
si = self.step_index - i # pyright: ignore
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
assert mi is not None
D1s.append((mi - m0) / rk) # pyright: ignore
if len(rks) > 0:
rks = torch.stack(rks)
one = torch.ones(1, device=device, dtype=rks.dtype)
rks = torch.cat([rks, one])
else:
rks = torch.ones(1, device=device, dtype=h.dtype)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.stack(b)
if D1s is not None and len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = 0.5 * torch.ones(1, dtype=x.dtype, device=device)
else:
assert isinstance(R, torch.Tensor)
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum(
"k,bkc...->bc...", rhos_p, D1s
) # pyright: ignore
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum(
"k,bkc...->bc...", rhos_p, D1s
) # pyright: ignore
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.to(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: torch.Tensor,
*args,
last_sample: torch.Tensor = None,
this_sample: torch.Tensor = None,
order: int | None = None, # pyright: ignore
**kwargs,
) -> torch.Tensor:
"""
One step for the UniC (B(h) version).
Args:
this_model_output (`torch.Tensor`):
The model outputs at `x_t`.
this_timestep (`int`):
The current timestep `t`.
last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`.
order (`int`):
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
Returns:
`torch.Tensor`:
The corrected sample tensor at the current timestep.
"""
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
if last_sample is None:
if len(args) > 1:
last_sample = args[1]
else:
raise ValueError(" missing`last_sample` as a required keyword argument")
if this_sample is None:
if len(args) > 2:
this_sample = args[2]
else:
raise ValueError(" missing`this_sample` as a required keyword argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError(" missing`order` as a required keyword argument")
if this_timestep is not None:
deprecate(
"this_timestep",
"1.0.0",
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
sigma_t, sigma_s0 = (
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
) # pyright: ignore
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = this_sample.device
# Build rks and D1s fully on device to avoid any host-device sync
# Fast paths for small orders (common cases: 1 or 2)
if order == 1:
rks = torch.ones(1, device=device, dtype=h.dtype)
D1s = None
elif order == 2:
# order == 2 -> only one historical point is used
si = self.step_index - 2 # i = 1
mi = model_output_list[-2]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h # 0-dim tensor on device
# rks = [rk, 1.0] but keep it on device without list->tensor sync
rks = torch.stack((rk, torch.ones_like(rk)))
assert mi is not None
# D1s shape: (B, K=1, C, ...) to match later einsum over K
D1s = ((mi - m0) / rk).unsqueeze(1) # pyright: ignore
else:
rks_list = []
D1s_list = []
for i in range(1, order):
si = self.step_index - (i + 1)
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks_list.append(rk)
assert mi is not None
D1s_list.append((mi - m0) / rk) # pyright: ignore
# Append 1.0 as a device tensor to rks
rks = torch.stack(rks_list + [torch.ones_like(rks_list[0])])
D1s = torch.stack(D1s_list, dim=1) if len(D1s_list) > 0 else None
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
# Avoid torch.tensor(list_of_gpu_scalars) which syncs to host
b = torch.stack(b)
# D1s is already prepared above for order==2; remains None for order==1
# for order 1, we use a simplified version
if order == 1:
rhos_c = 0.5 * torch.ones(1, dtype=x.dtype, device=device)
elif order == 2:
# Manually solve the 2x2 linear system to avoid device synchronization from torch.linalg.solve
# R = [[1, 1], [rk, 1]], where rk = rks[0]
rk = rks[0]
det = 1 - rk
# Using Cramer's rule to solve for rhos_c = [x0, x1]
# x0 = (b0 - b1) / det
# x1 = (b1 - rk * b0) / det
rhos_c_0 = (b[0] - b[1]) / det
rhos_c_1 = (b[1] - rk * b[0]) / det
rhos_c = torch.stack([rhos_c_0, rhos_c_1])
else:
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.to(x.dtype)
return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None) -> int:
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
step_index: int = indices[pos].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep) -> None:
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.Tensor,
timestep: int | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
generator=None,
) -> SchedulerOutput | tuple:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep UniPC.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
use_corrector = (
self.step_index > 0
and self.step_index - 1 not in self.disable_corrector
and self.last_sample is not None # pyright: ignore
)
sample = sample.to(model_output.device)
model_output_convert = self.convert_model_output(model_output, sample=sample)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep # pyright: ignore
if self.config.lower_order_final:
this_order = min(
self.config.solver_order, len(self.timesteps) - self.step_index
) # pyright: ignore
else:
this_order = self.config.solver_order
self.this_order: int = min(
this_order, self.lower_order_nums + 1
) # warmup for multistep
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
assert self._step_index is not None
self._step_index += 1 # pyright: ignore
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(
device=original_samples.device, dtype=original_samples.dtype
)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(
original_samples.device, dtype=torch.float32
)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps) for t in timesteps
]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
EntryClass = FlowUniPCMultistepScheduler
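# --- Illustrative sketch (not part of the original file) -------------------
# `multistep_uni_c_bh_update` above solves the order-2 corrector system
# R @ rhos_c = b with R = [[1, 1], [rk, 1]] by hand (Cramer's rule) to avoid
# the host-device sync incurred by `torch.linalg.solve`. This sketch, using
# made-up numbers, checks that both routes agree.
def _unipc_order2_cramer_demo() -> None:
    import torch

    rk = torch.tensor(0.37)
    b = torch.tensor([1.25, -0.40])
    one = torch.ones(())
    R = torch.stack([torch.stack([one, one]), torch.stack([rk, one])])

    # Cramer's rule, as written in the scheduler above
    det = 1 - rk
    rhos_manual = torch.stack([(b[0] - b[1]) / det, (b[1] - rk * b[0]) / det])

    # Generic linear solve for comparison
    rhos_solve = torch.linalg.solve(R, b)

    assert torch.allclose(rhos_manual, rhos_solve)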
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class SelfForcingFlowMatchSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class SelfForcingFlowMatchScheduler(BaseScheduler, ConfigMixin, SchedulerMixin):
config_name = "scheduler_config.json"
order = 1
@register_to_config
def __init__(
self,
num_inference_steps=100,
num_train_timesteps=1000,
shift=3.0,
sigma_max=1.0,
sigma_min=0.003 / 1.002,
inverse_timesteps=False,
extra_one_step=False,
reverse_sigmas=False,
training=False,
):
self.num_train_timesteps = num_train_timesteps
self.shift = shift
self.sigma_max = sigma_max
self.sigma_min = sigma_min
self.inverse_timesteps = inverse_timesteps
self.extra_one_step = extra_one_step
self.reverse_sigmas = reverse_sigmas
self.set_timesteps(num_inference_steps, training=training)
def set_timesteps(
self,
num_inference_steps=100,
denoising_strength=1.0,
training=False,
return_dict=False,
**kwargs,
):
sigma_start = (
self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
)
if self.extra_one_step:
self.sigmas = torch.linspace(
sigma_start, self.sigma_min, num_inference_steps + 1
)[:-1]
else:
self.sigmas = torch.linspace(
sigma_start, self.sigma_min, num_inference_steps
)
if self.inverse_timesteps:
self.sigmas = torch.flip(self.sigmas, dims=[0])
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
if self.reverse_sigmas:
self.sigmas = 1 - self.sigmas
self.timesteps = self.sigmas * self.num_train_timesteps
if training:
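            # Build a bell-shaped weighting over the timestep schedule,
            # normalized so the weights sum to num_inference_steps; it is
            # looked up later by `training_weight`.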
x = self.timesteps
y = torch.exp(
-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2
)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def step(
self,
model_output: torch.FloatTensor,
timestep: torch.FloatTensor,
sample: torch.FloatTensor,
to_final=False,
return_dict=False,
**kwargs,
):
if timestep.ndim == 2:
timestep = timestep.flatten(0, 1)
elif timestep.ndim == 0:
# handles the case where timestep is a scalar, this occurs when we
# use this scheduler for ODE trajectory
timestep = timestep.unsqueeze(0)
self.sigmas = self.sigmas.to(model_output.device)
self.timesteps = self.timesteps.to(model_output.device)
timestep = timestep.to(model_output.device)
timestep_id = torch.argmin(
(self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
)
sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
if to_final or (timestep_id + 1 >= len(self.timesteps)).any():
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
else:
sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
prev_sample = sample + model_output * (sigma_ - sigma)
if isinstance(prev_sample, torch.Tensor | float) and not return_dict:
return (prev_sample,)
return SelfForcingFlowMatchSchedulerOutput(prev_sample=prev_sample)
def add_noise(self, original_samples, noise, timestep):
"""
Diffusion forward corruption process.
Input:
        - original_samples: the clean latent with shape [B*T, C, H, W]
- noise: the noise with shape [B*T, C, H, W]
- timestep: the timestep with shape [B*T]
Output: the corrupted latent with shape [B*T, C, H, W]
"""
if timestep.ndim == 2:
timestep = timestep.flatten(0, 1)
self.sigmas = self.sigmas.to(noise.device)
self.timesteps = self.timesteps.to(noise.device)
timestep_id = torch.argmin(
(self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
)
sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
sample = (1 - sigma) * original_samples + sigma * noise
return sample.type_as(noise)
def training_target(self, sample, noise, timestep):
target = noise - sample
return target
def training_weight(self, timestep):
"""
Input:
- timestep: the timestep with shape [B*T]
Output: the corresponding weighting [B*T]
"""
if timestep.ndim == 2:
timestep = timestep.flatten(0, 1)
self.linear_timesteps_weights = self.linear_timesteps_weights.to(
timestep.device
)
timestep_id = torch.argmin(
(self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0
)
weights = self.linear_timesteps_weights[timestep_id]
return weights
def scale_model_input(
self, sample: torch.Tensor, timestep: int | None = None
) -> torch.Tensor:
return sample
def set_shift(self, shift: float) -> None:
self.shift = shift
EntryClass = SelfForcingFlowMatchScheduler
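# --- Illustrative sketch (not part of the original file) -------------------
# A standalone reproduction of the sigma schedule built by
# `SelfForcingFlowMatchScheduler.set_timesteps` above, assuming the default
# arguments (denoising_strength=1.0, extra_one_step=False, no inversion or
# reversal). The step count below is made up for illustration only.
def _self_forcing_schedule_demo() -> None:
    import torch

    num_inference_steps = 4
    num_train_timesteps = 1000
    shift = 3.0
    sigma_max, sigma_min = 1.0, 0.003 / 1.002

    # Linear sigma grid from sigma_max down to sigma_min ...
    sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
    # ... then the shift transform: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    timesteps = sigmas * num_train_timesteps

    assert float(sigmas[0]) == 1.0
    assert float(timesteps[0]) == num_train_timesteps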