Unverified Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# SPDX-License-Identifier: Apache-2.0
from .vmoba import moba_attn_varlen, process_moba_input, process_moba_output
# SPDX-License-Identifier: Apache-2.0
# Adapt from https://github.com/KwaiVGI/VMoBA/blob/main/src/vmoba.py
import random
import time
from typing import Tuple
import torch
try:
from flash_attn import ( # Use the new flash attention function
flash_attn_varlen_func,
)
from flash_attn.flash_attn_interface import (
_flash_attn_varlen_backward,
_flash_attn_varlen_forward,
)
except ImportError:
def _unsupported(*args, **kwargs):
raise ImportError(
"flash-attn is not installed. Please install it, e.g., `pip install flash-attn`."
)
_flash_attn_varlen_forward = _unsupported
_flash_attn_varlen_backward = _unsupported
flash_attn_varlen_func = _unsupported
from functools import lru_cache
from einops import rearrange
@lru_cache(maxsize=16)
def calc_chunks(cu_seqlen, moba_chunk_size):
"""
Calculate chunk boundaries.
For vision tasks we include all chunks (even the last one which might be shorter)
so that every chunk can be selected.
"""
batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1]
batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size
cu_num_chunk = torch.ones(
batch_num_chunk.numel() + 1,
device=cu_seqlen.device,
dtype=batch_num_chunk.dtype,
)
cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0)
num_chunk = cu_num_chunk[-1]
chunk_sizes = torch.full(
(num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device
)
chunk_sizes[0] = 0
batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size
chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size
cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32)
chunk_to_batch = torch.zeros(
(num_chunk,), dtype=torch.int32, device=cu_seqlen.device
)
chunk_to_batch[cu_num_chunk[1:-1]] = 1
chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32)
# Do not filter out any chunk
filtered_chunk_indices = torch.arange(
num_chunk, device=cu_seqlen.device, dtype=torch.int32
)
num_filtered_chunk = num_chunk
return cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch
# --- Threshold Selection Helper Functions ---
def _select_threshold_query_head(
gate: torch.Tensor,
valid_gate_mask: torch.Tensor,
gate_self_chunk_mask: torch.Tensor,
simsum_threshold: float,
) -> torch.Tensor:
"""
Selects chunks for each <query, head> pair based on threshold.
Normalization and sorting happen along the chunk dimension (dim=0).
"""
C, H, S = gate.shape
eps = 1e-6
# LSE‐style normalization per <head, query> (across chunks)
gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max
gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min
row_min = gate_min_val.amin(dim=0) # (H, S)
row_max = gate_masked.amax(dim=0) # (H, S)
denom = row_max - row_min
denom = torch.where(
denom <= eps, torch.ones_like(denom), denom
) # avoid divide‑by‑zero
gate_norm = (gate - row_min.unsqueeze(0)) / denom.unsqueeze(0)
gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S)
# 1) pull out the self‐chunk’s normalized weight for each <head,seq>
self_norm = (gate_norm * gate_self_chunk_mask).sum(dim=0) # (H, S)
# 2) compute how much more normalized weight we need beyond self
total_norm_sum = gate_norm.sum(dim=0) # (H, S)
remain_ratio = simsum_threshold - self_norm / (total_norm_sum + eps) # (H, S)
remain_ratio = torch.clamp(
remain_ratio, min=0.0
) # if already ≥ thresh, no extra needed
# 3) zero out the self‐chunk in a copy, so we only sort “others”
others_norm = gate_norm.clone()
others_norm[gate_self_chunk_mask] = 0.0
# 4) sort the other chunks by descending norm, per <head,seq>
sorted_norm, sorted_idx = torch.sort(
others_norm, descending=True, dim=0
) # (C, H, S)
# 5) cumulative‑sum the sorted norms per <head,seq>
cumsum_others = sorted_norm.cumsum(dim=0) # (C, H, S)
# 6) for each <head,seq>, find the smallest k where cumsum_ratio ≥ remain_ratio
ratio = cumsum_others / (total_norm_sum.unsqueeze(0) + eps) # (C, H, S)
cond = ratio >= remain_ratio.unsqueeze(0) # (C, H, S) boolean mask
any_cond = cond.any(dim=0) # (H, S)
# Find the index of the first True value along dim 0. If none, use C-1.
cutoff = torch.where(
any_cond,
cond.float().argmax(dim=0),
        torch.full_like(any_cond, fill_value=C - 1, dtype=torch.long),
) # (H, S)
# 7) build a mask in sorted order up to that cutoff
idx_range = torch.arange(C, device=gate.device).view(-1, 1, 1) # (C, 1, 1)
sorted_mask = idx_range <= cutoff.unsqueeze(0) # (C, H, S)
# 8) scatter it back to original chunk order
others_mask = torch.zeros_like(gate, dtype=torch.bool)
others_mask.scatter_(0, sorted_idx, sorted_mask)
# 9) finally, include every self‐chunk plus all selected others
final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask)
return final_gate_mask
def _select_threshold_block(
gate: torch.Tensor,
valid_gate_mask: torch.Tensor,
gate_self_chunk_mask: torch.Tensor,
simsum_threshold: float,
) -> torch.Tensor:
"""
Selects <query, head> pairs for each block based on threshold.
Normalization and sorting happen across the head and sequence dimensions (dim=1, 2).
"""
C, H, S = gate.shape
HS = H * S
eps = 1e-6
# LSE‐style normalization per block (across heads and queries)
gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max
gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min
block_max = gate_masked.amax(dim=(1, 2), keepdim=True) # (C, 1, 1)
block_min = gate_min_val.amin(dim=(1, 2), keepdim=True) # (C, 1, 1)
block_denom = block_max - block_min
block_denom = torch.where(
block_denom <= eps, torch.ones_like(block_denom), block_denom
) # (C, 1, 1)
gate_norm = (gate - block_min) / block_denom # (C, H, S)
gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S)
# 1) identify normalized weights of entries that *are* self-chunks (from query perspective)
self_norm_entries = gate_norm * gate_self_chunk_mask # (C, H, S)
# Sum these weights *per block*
self_norm_sum_per_block = self_norm_entries.sum(dim=(1, 2)) # (C,)
# 2) compute how much more normalized weight each block needs beyond its self-chunk contributions
total_norm_sum_per_block = gate_norm.sum(dim=(1, 2)) # (C,)
remain_ratio = simsum_threshold - self_norm_sum_per_block / (
total_norm_sum_per_block + eps
) # (C,)
remain_ratio = torch.clamp(remain_ratio, min=0.0) # (C,)
# 3) zero out the self‐chunk entries in a copy, so we only sort “others”
others_norm = gate_norm.clone()
others_norm[gate_self_chunk_mask] = 0.0 # Zero out self entries
# 4) sort the other <head, seq> pairs by descending norm, per block
others_flat = others_norm.contiguous().view(C, HS) # (C, H*S)
sorted_others_flat, sorted_indices_flat = torch.sort(
others_flat, dim=1, descending=True
) # (C, H*S)
# 5) cumulative‑sum the sorted norms per block
cumsum_others_flat = sorted_others_flat.cumsum(dim=1) # (C, H*S)
# 6) for each block, find the smallest k where cumsum_ratio ≥ remain_ratio
ratio_flat = cumsum_others_flat / (
total_norm_sum_per_block.unsqueeze(1) + eps
) # (C, H*S)
cond_flat = ratio_flat >= remain_ratio.unsqueeze(1) # (C, H*S) boolean mask
any_cond = cond_flat.any(dim=1) # (C,)
# Find the index of the first True value along dim 1. If none, use HS-1.
cutoff_flat = torch.where(
any_cond,
cond_flat.float().argmax(dim=1),
        torch.full_like(any_cond, fill_value=HS - 1, dtype=torch.long),
) # (C,)
# 7) build a mask in sorted order up to that cutoff per block
idx_range_flat = torch.arange(HS, device=gate.device).unsqueeze(0) # (1, H*S)
sorted_mask_flat = idx_range_flat <= cutoff_flat.unsqueeze(1) # (C, H*S)
# 8) scatter it back to original <head, seq> order per block
others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool) # (C, H*S)
others_mask_flat.scatter_(1, sorted_indices_flat, sorted_mask_flat)
others_mask = others_mask_flat.view(C, H, S) # (C, H, S)
# 9) finally, include every self‐chunk entry plus all selected others
final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask)
return final_gate_mask
def _select_threshold_overall(
gate: torch.Tensor,
valid_gate_mask: torch.Tensor,
gate_self_chunk_mask: torch.Tensor,
simsum_threshold: float,
) -> torch.Tensor:
"""
Selects <chunk, query, head> triplets globally based on threshold.
Normalization and sorting happen across all valid entries.
"""
C, H, S = gate.shape
CHS = C * H * S
eps = 1e-6
# LSE‐style normalization globally across all valid entries
gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max
gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min
overall_max = gate_masked.max() # scalar
overall_min = gate_min_val.min() # scalar
overall_denom = overall_max - overall_min
overall_denom = torch.where(
overall_denom <= eps,
torch.tensor(1.0, device=gate.device, dtype=gate.dtype),
overall_denom,
)
gate_norm = (gate - overall_min) / overall_denom # (C, H, S)
gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S)
# 1) identify normalized weights of entries that *are* self-chunks
self_norm_entries = gate_norm * gate_self_chunk_mask # (C, H, S)
# Sum these weights globally
self_norm_sum_overall = self_norm_entries.sum() # scalar
# 2) compute how much more normalized weight is needed globally beyond self-chunk contributions
total_norm_sum_overall = gate_norm.sum() # scalar
remain_ratio = simsum_threshold - self_norm_sum_overall / (
total_norm_sum_overall + eps
) # scalar
remain_ratio = torch.clamp(remain_ratio, min=0.0) # scalar
# 3) zero out the self‐chunk entries in a copy, so we only sort “others”
others_norm = gate_norm.clone()
others_norm[gate_self_chunk_mask] = 0.0 # Zero out self entries
# 4) sort all other entries by descending norm, globally
others_flat = others_norm.flatten() # (C*H*S,)
valid_others_mask_flat = (
valid_gate_mask.flatten() & ~gate_self_chunk_mask.flatten()
) # Mask for valid, non-self entries
# Only sort the valid 'other' entries
valid_others_indices = torch.where(valid_others_mask_flat)[0]
valid_others_values = others_flat[valid_others_indices]
sorted_others_values, sort_perm = torch.sort(
valid_others_values, descending=True
) # (N_valid_others,)
sorted_original_indices = valid_others_indices[
sort_perm
] # Original indices in C*H*S space, sorted by value
# 5) cumulative‑sum the sorted valid 'other' norms globally
cumsum_others_values = sorted_others_values.cumsum(dim=0) # (N_valid_others,)
# 6) find the smallest k where cumsum_ratio ≥ remain_ratio globally
ratio_values = cumsum_others_values / (
total_norm_sum_overall + eps
) # (N_valid_others,)
cond_values = ratio_values >= remain_ratio # (N_valid_others,) boolean mask
any_cond = cond_values.any() # scalar
# Find the index of the first True value in the *sorted* list. If none, use all valid others.
cutoff_idx_in_sorted = torch.where(
any_cond,
cond_values.float().argmax(dim=0),
torch.tensor(
len(sorted_others_values) - 1, device=gate.device, dtype=torch.long
),
)
# 7) build a mask selecting the top-k others based on the cutoff
# Select the original indices corresponding to the top entries in the sorted list
selected_other_indices = sorted_original_indices[: cutoff_idx_in_sorted + 1]
# 8) create the mask in the original flat shape
others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool) # (C*H*S,)
if selected_other_indices.numel() > 0: # Check if any 'other' indices were selected
others_mask_flat[selected_other_indices] = True
others_mask = others_mask_flat.view(C, H, S) # (C, H, S)
# 9) finally, include every self‐chunk entry plus all selected others
final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask)
return final_gate_mask
def _select_threshold_head_global(
gate: torch.Tensor,
valid_gate_mask: torch.Tensor,
gate_self_chunk_mask: torch.Tensor,
simsum_threshold: float,
) -> torch.Tensor:
"""
Selects <chunk, query> globally for each head based on threshold.
"""
C, H, S = gate.shape
eps = 1e-6
# 1) LSE‐style normalization per head (across chunks and sequence dims)
gate_masked = torch.where(valid_gate_mask, gate, -torch.inf)
gate_min_val = torch.where(valid_gate_mask, gate, torch.inf)
max_per_head = gate_masked.amax(dim=(0, 2), keepdim=True) # (1, H, 1)
min_per_head = gate_min_val.amin(dim=(0, 2), keepdim=True) # (1, H, 1)
denom = max_per_head - min_per_head
denom = torch.where(denom <= eps, torch.ones_like(denom), denom)
gate_norm = (gate - min_per_head) / denom
gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S)
# 2) sum normalized self‐chunk contributions per head
self_norm_sum = (gate_norm * gate_self_chunk_mask).sum(dim=(0, 2)) # (H,)
# 3) total normalized sum per head
total_norm_sum = gate_norm.sum(dim=(0, 2)) # (H,)
# 4) how much more normalized weight needed per head
remain_ratio = simsum_threshold - self_norm_sum / (total_norm_sum + eps) # (H,)
remain_ratio = torch.clamp(remain_ratio, min=0.0)
# 5) zero out self‐chunk entries to focus on "others"
others_norm = gate_norm.clone()
others_norm[gate_self_chunk_mask] = 0.0 # (C, H, S)
# 6) flatten chunk and sequence dims, per head
CS = C * S
others_flat = others_norm.permute(1, 0, 2).reshape(H, CS) # (H, C*S)
valid_flat = (
(valid_gate_mask & ~gate_self_chunk_mask).permute(1, 0, 2).reshape(H, CS)
) # (H, C*S)
# 7) vectorized selection of “others” per head
masked_flat = torch.where(valid_flat, others_flat, torch.zeros_like(others_flat))
sorted_vals, sorted_idx = torch.sort(
masked_flat, dim=1, descending=True
) # (H, C*S)
cumsum_vals = sorted_vals.cumsum(dim=1) # (H, C*S)
ratio_vals = cumsum_vals / (total_norm_sum.unsqueeze(1) + eps) # (H, C*S)
cond = ratio_vals >= remain_ratio.unsqueeze(1) # (H, C*S)
has_cutoff = cond.any(dim=1) # (H,)
default = torch.full((H,), CS - 1, device=gate.device, dtype=torch.long)
cutoff = torch.where(has_cutoff, cond.float().argmax(dim=1), default) # (H,)
idx_range = torch.arange(CS, device=gate.device).unsqueeze(0) # (1, C*S)
sorted_mask = idx_range <= cutoff.unsqueeze(1) # (H, C*S)
selected_flat = torch.zeros_like(valid_flat) # (H, C*S)
selected_flat.scatter_(1, sorted_idx, sorted_mask) # (H, C*S)
# 8) reshape selection mask back to (C, H, S)
others_mask = selected_flat.reshape(H, C, S).permute(1, 0, 2) # (C, H, S)
# 9) include self‐chunks plus selected others, and obey valid mask
final_gate_mask = valid_gate_mask & (gate_self_chunk_mask | others_mask)
return final_gate_mask
class MixedAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
self_attn_cu_seqlen,
moba_q,
moba_kv,
moba_cu_seqlen_q,
moba_cu_seqlen_kv,
max_seqlen,
moba_chunk_size,
moba_q_sh_indices,
):
ctx.max_seqlen = max_seqlen
ctx.moba_chunk_size = moba_chunk_size
ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5)
# Non-causal self-attention branch
# return out, softmax_lse, S_dmask, rng_state
self_attn_out_sh, self_attn_lse_hs, _, _ = _flash_attn_varlen_forward(
q=q,
k=k,
v=v,
cu_seqlens_q=self_attn_cu_seqlen,
cu_seqlens_k=self_attn_cu_seqlen,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
softmax_scale=softmax_scale,
causal=False,
dropout_p=0.0,
)
# MOBA attention branch (non-causal)
moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward(
q=moba_q,
k=moba_kv[:, 0],
v=moba_kv[:, 1],
cu_seqlens_q=moba_cu_seqlen_q,
cu_seqlens_k=moba_cu_seqlen_kv,
max_seqlen_q=max_seqlen,
max_seqlen_k=moba_chunk_size,
softmax_scale=softmax_scale,
causal=False,
dropout_p=0.0,
)
self_attn_lse_sh = self_attn_lse_hs.t().contiguous()
moba_attn_lse = moba_attn_lse_hs.t().contiguous()
output = torch.zeros(
(q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)
output_2d = output.view(-1, q.shape[2])
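        # Merge the two branches with a numerically stable log-sum-exp:
        # take the elementwise max of both LSE values for queries that appear in the
        # MoBA branch, shift each LSE by that max, sum the exponentials, and rescale
        # each branch's output by exp(lse - merged_lse) before accumulating.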
max_lse_1d = self_attn_lse_sh.view(-1)
max_lse_1d = max_lse_1d.index_reduce(
0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax"
)
self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh)
moba_attn_lse = (
moba_attn_lse.view(-1)
.sub(max_lse_1d.index_select(0, moba_q_sh_indices))
.reshape_as(moba_attn_lse)
)
mixed_attn_se_sh = self_attn_lse_sh.exp()
moba_attn_se = moba_attn_lse.exp()
mixed_attn_se_sh.view(-1).index_add_(
0, moba_q_sh_indices, moba_attn_se.view(-1)
)
mixed_attn_lse_sh = mixed_attn_se_sh.log()
# Combine self-attention output
factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [S, H]
self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1)
output_2d += self_attn_out_sh.reshape_as(output_2d)
# Combine MOBA attention output
mixed_attn_lse = (
mixed_attn_lse_sh.view(-1)
.index_select(0, moba_q_sh_indices)
.view_as(moba_attn_lse)
)
factor = (moba_attn_lse - mixed_attn_lse).exp() # [S, H]
moba_attn_out = moba_attn_out * factor.unsqueeze(-1)
raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1])
output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out)
output = output.to(q.dtype)
mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh)
ctx.save_for_backward(
output,
mixed_attn_lse_sh,
q,
k,
v,
self_attn_cu_seqlen,
moba_q,
moba_kv,
moba_cu_seqlen_q,
moba_cu_seqlen_kv,
moba_q_sh_indices,
)
return output
@staticmethod
def backward(ctx, d_output):
max_seqlen = ctx.max_seqlen
moba_chunk_size = ctx.moba_chunk_size
softmax_scale = ctx.softmax_scale
(
output,
mixed_attn_vlse_sh,
q,
k,
v,
self_attn_cu_seqlen,
moba_q,
moba_kv,
moba_cu_seqlen_q,
moba_cu_seqlen_kv,
moba_q_sh_indices,
) = ctx.saved_tensors
d_output = d_output.contiguous()
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
_ = _flash_attn_varlen_backward(
dout=d_output,
q=q,
k=k,
v=v,
out=output,
softmax_lse=mixed_attn_vlse_sh.t().contiguous(),
dq=dq,
dk=dk,
dv=dv,
cu_seqlens_q=self_attn_cu_seqlen,
cu_seqlens_k=self_attn_cu_seqlen,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
softmax_scale=softmax_scale,
causal=False,
dropout_p=0.0,
softcap=0.0,
alibi_slopes=None,
deterministic=True,
window_size_left=-1,
window_size_right=-1,
)
headdim = q.shape[-1]
d_moba_output = (
d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)
)
moba_output = (
output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)
)
mixed_attn_vlse = (
mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1)
)
dmq = torch.empty_like(moba_q)
dmkv = torch.empty_like(moba_kv)
_ = _flash_attn_varlen_backward(
dout=d_moba_output,
q=moba_q,
k=moba_kv[:, 0],
v=moba_kv[:, 1],
out=moba_output,
softmax_lse=mixed_attn_vlse,
dq=dmq,
dk=dmkv[:, 0],
dv=dmkv[:, 1],
cu_seqlens_q=moba_cu_seqlen_q,
cu_seqlens_k=moba_cu_seqlen_kv,
max_seqlen_q=max_seqlen,
max_seqlen_k=moba_chunk_size,
softmax_scale=softmax_scale,
causal=False,
dropout_p=0.0,
softcap=0.0,
alibi_slopes=None,
deterministic=True,
window_size_left=-1,
window_size_right=-1,
)
return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None
def moba_attn_varlen(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
moba_chunk_size: int,
moba_topk: int,
select_mode: str = "threshold", # "topk" or "threshold"
simsum_threshold: float = 0.25,
threshold_type: str = "query_head",
) -> torch.Tensor:
"""
Accelerated MOBA attention for vision tasks with proper LSE normalization.
This version:
- Splits KV into chunks.
- For each query head, selects the top-k relevant KV chunks (including the self chunk)
by amplifying the diagonal (self-chunk) logits.
- Aggregates the attention outputs from the selected chunks using a log-sum-exp
reduction so that attending to each query over the selected chunks is equivalent
to the original algorithm.
"""
# Stack keys and values.
kv = torch.stack((k, v), dim=1)
seqlen, num_head, head_dim = q.shape
# Compute chunk boundaries.
cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch = calc_chunks(
cu_seqlens, moba_chunk_size
)
self_attn_cu_seqlen = cu_chunk
# Update top-k selection to include the self chunk.
moba_topk = min(moba_topk, num_filtered_chunk)
# --- Build filtered KV from chunks ---
chunk_starts = cu_chunk[filtered_chunk_indices] # [num_filtered_chunk]
chunk_ends = cu_chunk[filtered_chunk_indices + 1] # [num_filtered_chunk]
chunk_lengths = chunk_ends - chunk_starts # [num_filtered_chunk]
max_chunk_len = int(chunk_lengths.max().item())
range_tensor = torch.arange(
max_chunk_len, device=kv.device, dtype=chunk_starts.dtype
).unsqueeze(0)
indices = chunk_starts.unsqueeze(1) + range_tensor
indices = torch.clamp(indices, max=kv.shape[0] - 1)
valid_mask = range_tensor < chunk_lengths.unsqueeze(1)
gathered = kv[indices.view(-1)].view(
num_filtered_chunk, max_chunk_len, *kv.shape[1:]
)
gathered = gathered * valid_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).type_as(
gathered
)
# Compute key_gate_weight over valid tokens.
key_values = gathered[
:, :, 0
].float() # [num_filtered_chunk, max_chunk_len, num_head, head_dim]
valid_mask_exp = valid_mask.unsqueeze(-1).unsqueeze(-1)
key_sum = (key_values * valid_mask_exp).sum(dim=1)
divisor = valid_mask.sum(dim=1).unsqueeze(-1).unsqueeze(-1)
key_gate_weight = key_sum / divisor # [num_filtered_chunk, num_head, head_dim]
# Compute gate logits between key_gate_weight and queries.
q_float = q.float()
# gate = torch.einsum("nhd,shd->nhs", key_gate_weight, q_float) # [num_filtered_chunk, num_head, seqlen]
gate = torch.bmm(
key_gate_weight.permute(1, 0, 2), q_float.permute(1, 0, 2).transpose(1, 2)
).permute(1, 0, 2)
# Amplify the diagonal (self chunk) contributions.
gate_seq_idx = (
torch.arange(seqlen, device=q.device, dtype=torch.int32)
.unsqueeze(0)
.expand(num_filtered_chunk, seqlen)
)
chunk_start = cu_chunk[filtered_chunk_indices] # [num_filtered_chunk]
chunk_end = cu_chunk[filtered_chunk_indices + 1] # [num_filtered_chunk]
gate_self_chunk_mask = (
(
(gate_seq_idx >= chunk_start.unsqueeze(1))
& (gate_seq_idx < chunk_end.unsqueeze(1))
)
.unsqueeze(1)
.expand(-1, num_head, -1)
)
amplification_factor = 1e9 # Example factor; adjust as needed.
origin_gate = gate.clone()
gate = gate.clone()
if select_mode == "topk":
gate[gate_self_chunk_mask] += amplification_factor
# Exclude positions that are outside the valid batch boundaries.
batch_starts = cu_seqlens[chunk_to_batch[filtered_chunk_indices]]
batch_ends = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1]
gate_batch_start_mask = gate_seq_idx < batch_starts.unsqueeze(1)
gate_batch_end_mask = gate_seq_idx >= batch_ends.unsqueeze(1)
gate_inf_mask = gate_batch_start_mask | gate_batch_end_mask
gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf"))
if select_mode == "topk":
# We amplify self‐chunk in gate already, so self entries will rank highest.
valid_gate_mask = gate != -float("inf")
if threshold_type == "query_head":
# === per‐<head,seq> top-k across chunks (original behavior) ===
# gate: (C, H, S)
_, gate_topk_idx = torch.topk(
gate, k=moba_topk, dim=0, largest=True, sorted=False
)
gate_idx_mask = torch.zeros_like(gate, dtype=torch.bool)
gate_idx_mask.scatter_(0, gate_topk_idx, True)
gate_mask = valid_gate_mask & gate_idx_mask
elif threshold_type == "overall":
# === global top-k across all (chunk, head, seq) entries ===
C, H, S = gate.shape
flat_gate = gate.flatten()
flat_mask = valid_gate_mask.flatten()
flat_gate_masked = torch.where(flat_mask, flat_gate, -float("inf"))
# pick topk global entries
vals, idx = torch.topk(
flat_gate_masked, k=moba_topk * H * S, largest=True, sorted=False
)
others_mask_flat = torch.zeros_like(flat_mask, dtype=torch.bool)
others_mask_flat[idx] = True
gate_mask = (valid_gate_mask.flatten() & others_mask_flat).view(gate.shape)
elif threshold_type == "head_global":
# per-head top-k across all chunks and sequence positions
C, H, S = gate.shape
CS = C * S
flat_gate = gate.permute(1, 0, 2).reshape(H, CS)
flat_valid = valid_gate_mask.permute(1, 0, 2).reshape(H, CS)
flat_gate_masked = torch.where(
flat_valid, flat_gate, torch.full_like(flat_gate, -float("inf"))
)
# pick top-k indices per head
_, topk_idx = torch.topk(
flat_gate_masked, k=moba_topk * S, dim=1, largest=True, sorted=False
)
gate_idx_flat = torch.zeros_like(flat_valid, dtype=torch.bool)
gate_idx_flat.scatter_(1, topk_idx, True)
gate_mask = gate_idx_flat.reshape(H, C, S).permute(1, 0, 2)
else:
raise ValueError(
f"Invalid threshold_type for topk: {threshold_type}. "
"Choose 'query_head', 'block', or 'overall'."
)
elif select_mode == "threshold":
# Delegate to the specific thresholding function
valid_gate_mask = gate != -float("inf") # (num_chunk, num_head, seqlen)
if threshold_type == "query_head":
gate_mask = _select_threshold_query_head(
gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold
)
elif threshold_type == "block":
gate_mask = _select_threshold_block(
gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold
)
elif threshold_type == "overall":
gate_mask = _select_threshold_overall(
gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold
)
elif threshold_type == "head_global":
gate_mask = _select_threshold_head_global(
gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold
)
else:
raise ValueError(
f"Invalid threshold_type: {threshold_type}. Choose 'query_head', 'block', or 'overall'."
)
else:
raise ValueError(
f"Invalid select_mode: {select_mode}. Choose 'topk' or 'threshold'."
)
# eliminate self_chunk in MoBA branch
gate_mask = gate_mask & ~gate_self_chunk_mask
# if gate_mask is all false, perform flash_attn instead
if gate_mask.sum() == 0:
return flash_attn_varlen_func(
q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=False
)
# Determine which query positions are selected.
# nonzero_indices has shape [N, 3] where each row is [chunk_index, head_index, seq_index].
moba_q_indices = gate_mask.reshape(gate_mask.shape[0], -1).nonzero(as_tuple=True)[
-1
] # [(h s k)]
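    # moba_q_indices are flat (head * seqlen + seq) positions; convert them to
    # (seq * num_head + head) so they can index output_2d, which is laid out as
    # [seqlen * num_head, head_dim].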
moba_q_sh_indices = (moba_q_indices % seqlen) * num_head + (
moba_q_indices // seqlen
)
moba_q = (
rearrange(q, "s h d -> (h s) d").index_select(0, moba_q_indices).unsqueeze(1)
)
# Build cumulative sequence lengths for the selected queries.
moba_seqlen_q = gate_mask.sum(dim=-1).flatten()
q_zero_mask = moba_seqlen_q == 0
valid_expert_mask = ~q_zero_mask
if q_zero_mask.sum() > 0:
moba_seqlen_q = moba_seqlen_q[valid_expert_mask]
moba_cu_seqlen_q = torch.cat(
(
torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype),
moba_seqlen_q.cumsum(dim=0),
),
dim=0,
).to(torch.int32)
# Rearrange gathered KV for the MOBA branch.
experts_tensor = rearrange(gathered, "nc cl two h d -> (nc h) cl two d")
valid_expert_lengths = (
chunk_lengths.unsqueeze(1)
.expand(num_filtered_chunk, num_head)
.reshape(-1)
.to(torch.int32)
)
if q_zero_mask.sum() > 0:
experts_tensor = experts_tensor[valid_expert_mask]
valid_expert_lengths = valid_expert_lengths[valid_expert_mask]
seq_range = torch.arange(
experts_tensor.shape[1], device=experts_tensor.device
).unsqueeze(0)
mask = seq_range < valid_expert_lengths.unsqueeze(1)
moba_kv = experts_tensor[mask] # Shape: ((nc h cl_valid) two d)
moba_kv = moba_kv.unsqueeze(2) # Shape: ((nc h cl_valid) two 1 d)
moba_cu_seqlen_kv = torch.cat(
[
torch.zeros(1, device=experts_tensor.device, dtype=torch.int32),
valid_expert_lengths.cumsum(dim=0),
],
dim=0,
).to(torch.int32)
assert (
moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape
), f"Mismatch between moba_cu_seqlen_kv.shape and moba_cu_seqlen_q.shape: {moba_cu_seqlen_kv.shape} vs {moba_cu_seqlen_q.shape}"
return MixedAttention.apply(
q,
k,
v,
self_attn_cu_seqlen,
moba_q,
moba_kv,
moba_cu_seqlen_q,
moba_cu_seqlen_kv,
max_seqlen,
moba_chunk_size,
moba_q_sh_indices,
)
def process_moba_input(
x,
patch_resolution,
chunk_size,
):
"""
Process inputs for the attention function.
Args:
x (torch.Tensor): Input tensor with shape [batch_size, num_patches, num_heads, head_dim].
patch_resolution (tuple): Tuple containing the patch resolution (t, h, w).
chunk_size (int): Size of the chunk. (maybe tuple or int, according to chunk type)
Returns:
torch.Tensor: Processed input tensor.
"""
if isinstance(chunk_size, float) or isinstance(chunk_size, int):
moba_chunk_size = int(chunk_size * patch_resolution[1] * patch_resolution[2])
else:
assert isinstance(
chunk_size, (Tuple, list)
), f"chunk_size should be a tuple, list, or int, now it is: {type(chunk_size)}"
if len(chunk_size) == 2:
assert (
patch_resolution[1] % chunk_size[0] == 0
and patch_resolution[2] % chunk_size[1] == 0
), f"spatial patch_resolution {patch_resolution[1:]} should be divisible by 2d chunk_size {chunk_size}"
nch, ncw = (
patch_resolution[1] // chunk_size[0],
patch_resolution[2] // chunk_size[1],
)
x = rearrange(
x,
"b (t nch ch ncw cw) n d -> b (nch ncw t ch cw) n d",
t=patch_resolution[0],
nch=nch,
ncw=ncw,
ch=chunk_size[0],
cw=chunk_size[1],
)
moba_chunk_size = patch_resolution[0] * chunk_size[0] * chunk_size[1]
elif len(chunk_size) == 3:
assert (
patch_resolution[0] % chunk_size[0] == 0
and patch_resolution[1] % chunk_size[1] == 0
and patch_resolution[2] % chunk_size[2] == 0
), f"patch_resolution {patch_resolution} should be divisible by 3d chunk_size {chunk_size}"
nct, nch, ncw = (
patch_resolution[0] // chunk_size[0],
patch_resolution[1] // chunk_size[1],
patch_resolution[2] // chunk_size[2],
)
x = rearrange(
x,
"b (nct ct nch ch ncw cw) n d -> b (nct nch ncw ct ch cw) n d",
nct=nct,
nch=nch,
ncw=ncw,
ct=chunk_size[0],
ch=chunk_size[1],
cw=chunk_size[2],
)
moba_chunk_size = chunk_size[0] * chunk_size[1] * chunk_size[2]
else:
raise ValueError(
f"chunk_size should be a int, or a tuple of length 2 or 3, now it is: {len(chunk_size)}"
)
return x, moba_chunk_size
def process_moba_output(
x,
patch_resolution,
chunk_size,
):
if isinstance(chunk_size, float) or isinstance(chunk_size, int):
pass
elif len(chunk_size) == 2:
x = rearrange(
x,
"b (nch ncw t ch cw) n d -> b (t nch ch ncw cw) n d",
nch=patch_resolution[1] // chunk_size[0],
ncw=patch_resolution[2] // chunk_size[1],
t=patch_resolution[0],
ch=chunk_size[0],
cw=chunk_size[1],
)
elif len(chunk_size) == 3:
x = rearrange(
x,
"b (nct nch ncw ct ch cw) n d -> b (nct ct nch ch ncw cw) n d",
nct=patch_resolution[0] // chunk_size[0],
nch=patch_resolution[1] // chunk_size[1],
ncw=patch_resolution[2] // chunk_size[2],
ct=chunk_size[0],
ch=chunk_size[1],
cw=chunk_size[2],
)
return x
# TEST
def generate_data(batch_size, seqlen, num_head, head_dim, dtype):
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
device = torch.cuda.current_device()
q = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to(
dtype=dtype, device="cuda"
)
k = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to(
dtype=dtype, device="cuda"
)
v = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to(
dtype=dtype, device="cuda"
)
print(f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}")
cu_seqlens = torch.arange(
0, q.shape[0] * q.shape[1] + 1, q.shape[1], dtype=torch.int32, device="cuda"
)
max_seqlen = q.shape[1]
q = rearrange(q, "b s ... -> (b s) ...")
k = rearrange(k, "b s ... -> (b s) ...")
v = rearrange(v, "b s ... -> (b s) ...")
return q, k, v, cu_seqlens, max_seqlen
def test_attn_varlen_moba_speed(
batch,
head,
seqlen,
head_dim,
moba_chunk_size,
moba_topk,
dtype=torch.bfloat16,
select_mode="threshold",
simsum_threshold=0.25,
threshold_type="query_head",
):
"""Speed test comparing flash_attn vs moba_attention"""
# Get data
q, k, v, cu_seqlen, max_seqlen = generate_data(batch, seqlen, head, head_dim, dtype)
print(
f"batch:{batch} head:{head} seqlen:{seqlen} chunk:{moba_chunk_size} topk:{moba_topk} select_mode: {select_mode} simsum_threshold:{simsum_threshold}"
)
vo_grad = torch.randn_like(q)
# Warmup
warmup_iters = 3
perf_test_iters = 10
# Warmup
for _ in range(warmup_iters):
o = flash_attn_varlen_func(
q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False
)
torch.autograd.backward(o, vo_grad)
torch.cuda.synchronize()
start_flash = time.perf_counter()
for _ in range(perf_test_iters):
o = flash_attn_varlen_func(
q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False
)
torch.autograd.backward(o, vo_grad)
torch.cuda.synchronize()
time_flash = (time.perf_counter() - start_flash) / perf_test_iters * 1000
# Warmup
for _ in range(warmup_iters):
om = moba_attn_varlen(
q,
k,
v,
cu_seqlen,
max_seqlen,
moba_chunk_size=moba_chunk_size,
moba_topk=moba_topk,
select_mode=select_mode,
simsum_threshold=simsum_threshold,
threshold_type=threshold_type,
)
torch.autograd.backward(om, vo_grad)
torch.cuda.synchronize()
start_moba = time.perf_counter()
for _ in range(perf_test_iters):
om = moba_attn_varlen(
q,
k,
v,
cu_seqlen,
max_seqlen,
moba_chunk_size=moba_chunk_size,
moba_topk=moba_topk,
select_mode=select_mode,
simsum_threshold=simsum_threshold,
threshold_type=threshold_type,
)
torch.autograd.backward(om, vo_grad)
torch.cuda.synchronize()
time_moba = (time.perf_counter() - start_moba) / perf_test_iters * 1000
print(f"Flash: {time_flash:.2f}ms, MoBA: {time_moba:.2f}ms")
print(f"Speedup: {time_flash / time_moba:.2f}x")
if __name__ == "__main__":
"""
CUDA_VISIBLE_DEVICES=1 \
python -u csrc/attn/vmoba_attn/vmoba/vmoba.py
"""
test_attn_varlen_moba_speed(
batch=1,
head=12,
seqlen=32760,
head_dim=128,
moba_chunk_size=32760 // 3 // 6 // 4,
moba_topk=3,
select_mode="threshold",
simsum_threshold=0.3,
threshold_type="query_head",
)
# sgl-diffusion CLI Inference
The sgl-diffusion CLI provides a quick way to access the sgl-diffusion inference pipeline for image and video generation.
## Prerequisites
- A working sgl-diffusion installation and the `sgl-diffusion` CLI available in `$PATH`.
- Python 3.10+ if you plan to use the OpenAI Python SDK.
## Supported Arguments
### Server Arguments
- `--model-path {MODEL_PATH}`: Path to the model or model ID
- `--num-gpus {NUM_GPUS}`: Number of GPUs to use
- `--tp-size {TP_SIZE}`: Tensor parallelism size (only for the encoder; should not be larger than 1 if text encoder offload is enabled, as layer-wise offload plus prefetch is faster)
- `--sp-size {SP_SIZE}`: Sequence parallelism size (typically should match the number of GPUs)
- `--ulysses-degree {ULYSSES_DEGREE}`: The degree of DeepSpeed-Ulysses-style SP in USP
- `--ring-degree {RING_DEGREE}`: The degree of ring attention-style SP in USP
### Sampling Parameters
- `--prompt {PROMPT}`: Text description for the video you want to generate
- `--num-inference-steps {STEPS}`: Number of denoising steps
- `--negative-prompt {PROMPT}`: Negative prompt to guide generation away from certain concepts
- `--seed {SEED}`: Random seed for reproducible generation
#### Image/Video Configuration
- `--height {HEIGHT}`: Height of the generated output
- `--width {WIDTH}`: Width of the generated output
- `--num-frames {NUM_FRAMES}`: Number of frames to generate
- `--fps {FPS}`: Frames per second for the saved output, if this is a video-generation task
#### Output Options
- `--output-path {PATH}`: Directory to save the generated video
- `--save-output`: Whether to save the image/video to disk
- `--return-frames`: Whether to return the raw frames
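Combining the server arguments and sampling parameters above, a typical invocation might look like the following (the model and parameter values here are only illustrative):
```bash
sglang generate \
  --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
  --num-gpus 2 --sp-size 2 \
  --prompt "A curious raccoon" \
  --negative-prompt "blurry, low quality" \
  --num-inference-steps 50 --seed 1024 \
  --height 480 --width 832 --num-frames 45 --fps 16 \
  --save-output --output-path outputs/
```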
### Using Configuration Files
Instead of specifying all parameters on the command line, you can use a configuration file:
```bash
sglang generate --config {CONFIG_FILE_PATH}
```
The configuration file should be in JSON or YAML format with the same parameter names as the CLI options. Command-line arguments take precedence over settings in the configuration file, allowing you to override specific values while keeping the rest from the configuration file.
Example configuration file (config.json):
```json
{
"model_path": "FastVideo/FastHunyuan-diffusers",
"prompt": "A beautiful woman in a red dress walking down a street",
"output_path": "outputs/",
"num_gpus": 2,
"sp_size": 2,
"tp_size": 1,
"num_frames": 45,
"height": 720,
"width": 1280,
"num_inference_steps": 6,
"seed": 1024,
"fps": 24,
"precision": "bf16",
"vae_precision": "fp16",
"vae_tiling": true,
"vae_sp": true,
"vae_config": {
"load_encoder": false,
"load_decoder": true,
"tile_sample_min_height": 256,
"tile_sample_min_width": 256
},
"text_encoder_precisions": [
"fp16",
"fp16"
],
"mask_strategy_file_path": null,
"enable_torch_compile": false
}
```
Or using YAML format (config.yaml):
```yaml
model_path: "FastVideo/FastHunyuan-diffusers"
prompt: "A beautiful woman in a red dress walking down a street"
output_path: "outputs/"
num_gpus: 2
sp_size: 2
tp_size: 1
num_frames: 45
height: 720
width: 1280
num_inference_steps: 6
seed: 1024
fps: 24
precision: "bf16"
vae_precision: "fp16"
vae_tiling: true
vae_sp: true
vae_config:
load_encoder: false
load_decoder: true
tile_sample_min_height: 256
tile_sample_min_width: 256
text_encoder_precisions:
- "fp16"
- "fp16"
mask_strategy_file_path: null
enable_torch_compile: false
```
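Because command-line arguments take precedence, a single value from the file can be overridden without editing it; for example (the step count here is illustrative):
```bash
# num_inference_steps from config.yaml is overridden by the CLI flag
sglang generate --config config.yaml --num-inference-steps 8
```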
To see all the options, you can use the `--help` flag:
```bash
sglang generate --help
```
## Serve
Launch the sgl-diffusion HTTP server and interact with it using the OpenAI SDK and curl. The server implements an OpenAI-compatible subset for Videos under the `/v1/videos` namespace.
### Start the server
Use the following command to launch the server:
```bash
SERVER_ARGS=(
--model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers
--text-encoder-cpu-offload
--pin-cpu-memory
--num-gpus 4
--ulysses-degree=2
--ring-degree=2
)
sglang serve "${SERVER_ARGS[@]}"
```
- **--model-path**: Which model to load. The example uses `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`.
- **--port**: HTTP port to listen on (the default here is `30010`).
Wait until the port is listening. In CI, the tests probe `127.0.0.1:30010` before sending requests.
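Outside of CI, one minimal way to block until the server is accepting connections is a small shell loop (host and port match the defaults assumed on this page):
```bash
# Retry once per second until something accepts connections on 127.0.0.1:30010
until python -c "import socket; socket.create_connection(('127.0.0.1', 30010), timeout=1)" 2>/dev/null; do
  sleep 1
done
```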
### OpenAI Python SDK usage
Initialize the client with a dummy API key and point `base_url` to your local server:
```python
from openai import OpenAI
client = OpenAI(api_key="sk-proj-1234567890", base_url="http://localhost:30010/v1")
```
- **Create a video**
```python
video = client.videos.create(prompt="A calico cat playing a piano on stage", size="1280x720")
print(video.id, video.status)
```
Response example fields include `id`, `status` (e.g., `queued` or `completed`), `size`, and `seconds`.
- **List videos**
```python
videos = client.videos.list()
for item in videos.data:
print(item.id, item.status)
```
- **Poll for completion and download content**
```python
import time
video = client.videos.create(prompt="A calico cat playing a piano on stage", size="1280x720")
video_id = video.id
# Simple polling loop
while True:
page = client.videos.list()
item = next((v for v in page.data if v.id == video_id), None)
if item and item.status == "completed":
break
time.sleep(5)
# Download binary content (MP4)
resp = client.videos.download_content(video_id=video_id)
content = resp.read() # bytes
with open("output.mp4", "wb") as f:
f.write(content)
```
### curl examples
- **Create a video**
```bash
curl -sS -X POST "http://localhost:30010/v1/videos" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-proj-1234567890" \
-d '{
"prompt": "A calico cat playing a piano on stage",
"size": "1280x720"
}'
```
- **List videos**
```bash
curl -sS -X GET "http://localhost:30010/v1/videos" \
-H "Authorization: Bearer sk-proj-1234567890"
```
- **Download video content**
```bash
curl -sS -L "http://localhost:30010/v1/videos/<VIDEO_ID>/content" \
-H "Authorization: Bearer sk-proj-1234567890" \
-o output.mp4
```
### API surface implemented here
The server exposes these endpoints (OpenAPI tag `videos`):
- `POST /v1/videos` — Create a generation job and return a queued `video` object.
- `GET /v1/videos` — List jobs.
- `GET /v1/videos/{video_id}/content` — Download binary content when ready (e.g., MP4).
### Reference
- OpenAI Videos API reference: `https://platform.openai.com/docs/api-reference/videos`
## Generate
Run a one-off generation task without launching a persistent server.
To use it, pass both server arguments and sampling parameters in one command, after the `generate` subcommand, for example:
```bash
SERVER_ARGS=(
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers
--text-encoder-cpu-offload
--pin-cpu-memory
--num-gpus 4
--ulysses-degree=2
--ring-degree=2
)
SAMPLING_ARGS=(
--prompt "A curious raccoon"
--save-output
--output-path outputs
--output-file-name "A curious raccoon.mp4"
)
sglang generate "${SERVER_ARGS[@]}" "${SAMPLING_ARGS[@]}"
```
Once the generation task has finished, the server will shut down automatically.
> [!NOTE]
> The HTTP server-related arguments are ignored in this subcommand.
# Install sgl-diffusion
You can install sgl-diffusion using one of the methods below.
This page primarily applies to common NVIDIA GPU platforms.
## Method 1: With pip or uv
It is recommended to use uv for a faster installation:
```bash
pip install --upgrade pip
pip install uv
uv pip install "sglang[diffusion]" --prerelease=allow
```
## Method 2: From source
```bash
# Use the latest release branch
git clone https://github.com/sgl-project/sglang.git
cd sglang
# Install the Python packages
pip install --upgrade pip
pip install -e "python/.[diffusion]"
# With uv
uv pip install --prerelease=allow -e "python/.[diffusion]"
```
**Quick fixes for common problems:**
- If you want to develop sgl-diffusion, it is recommended to use Docker. The Docker image is `lmsysorg/sgl-diffusion:latest`.
## Method 3: Using Docker
The Docker images are available on Docker Hub as `lmsysorg/sgl-diffusion`, built from the [Dockerfile](https://github.com/sgl-project/sgl-diffusion/tree/main/docker).
Replace `<secret>` below with your HuggingFace Hub [token](https://huggingface.co/docs/hub/en/security-tokens).
```bash
docker run --gpus all \
--shm-size 32g \
-p 30000:30000 \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HF_TOKEN=<secret>" \
--ipc=host \
lmsysorg/sglang:diffusion \
sglang generate --model-path black-forest-labs/FLUX.1-dev \
--prompt "A logo With Bold Large text: SGL Diffusion" \
--save-output
```
# Compatibility Matrix
The tables below show every supported model and the optimizations available for each.
The symbols used have the following meanings:
- ✅ = Full compatibility
- ❌ = No compatibility
- ⭕ = Does not apply to this model
## Models x Optimization
The `HuggingFace Model ID` can be passed directly to `from_pretrained()` methods, and sgl-diffusion will use the optimal
default parameters when initializing and generating videos.
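For example, any ID from the tables below can also be passed to the CLI documented earlier (the model chosen here is illustrative):
```bash
sglang generate --model-path FastVideo/FastWan2.1-T2V-1.3B-Diffusers \
  --prompt "A curious raccoon" --save-output --output-path outputs/
```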
### Video Generation Models
| Model Name | Hugging Face Model ID | Resolutions | TeaCache | Sliding Tile Attn | Sage Attn | Video Sparse Attention (VSA) |
|:-----------------------------|:--------------------------------------------------|:---------------------------------------------|:--------:|:-----------------:|:---------:|:----------------------------:|
| FastWan2.1 T2V 1.3B | `FastVideo/FastWan2.1-T2V-1.3B-Diffusers` | 480p | ⭕ | ⭕ | ⭕ | ✅ |
| FastWan2.2 TI2V 5B Full Attn | `FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers` | 720p | ⭕ | ⭕ | ⭕ | ✅ |
| Wan2.2 TI2V 5B | `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | 720p | ⭕ | ⭕ | ✅ | ⭕ |
| Wan2.2 T2V A14B | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | 480p<br>720p | ❌ | ❌ | ✅ | ⭕ |
| Wan2.2 I2V A14B | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | 480p<br>720p | ❌ | ❌ | ✅ | ⭕ |
| HunyuanVideo | `hunyuanvideo-community/HunyuanVideo` | 720×1280<br>544×960 | ❌ | ✅ | ✅ | ⭕ |
| FastHunyuan | `FastVideo/FastHunyuan-diffusers` | 720×1280<br>544×960 | ❌ | ✅ | ✅ | ⭕ |
| Wan2.1 T2V 1.3B | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | 480p | ✅ | ✅ | ✅ | ⭕ |
| Wan2.1 T2V 14B | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 480p, 720p | ✅ | ✅ | ✅ | ⭕ |
| Wan2.1 I2V 480P | `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` | 480p | ✅ | ✅ | ✅ | ⭕ |
| Wan2.1 I2V 720P | `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` | 720p | ✅ | ✅ | ✅ | ⭕ |
**Note**: Wan2.2 TI2V 5B has some quality issues when performing I2V generation. We are working on fixing this issue.
### Image Generation Models
| Model Name | HuggingFace Model ID | Resolutions | TeaCache | Sage Attn |
|:----------------|:-------------------------------|:---------------|:--------:|:---------:|
| FLUX.1-dev | `black-forest-labs/FLUX.1-dev` | Any resolution | ❌ | ❌ |
| Qwen Image | `Qwen/Qwen-Image` | Any resolution | ❌ | ❌ |
| Qwen Image Edit | `Qwen/Qwen-Image-Edit` | Any resolution | ❌ | ❌ |
## Special requirements
### Sliding Tile Attention
- Currently, only Hopper GPUs (H100s) are supported.
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import importlib.util
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/envs.py
import logging
import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
import diffusers
import torch
from packaging import version
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
SGL_DIFFUSION_RINGBUFFER_WARNING_INTERVAL: int = 60
SGL_DIFFUSION_NCCL_SO_PATH: str | None = None
LD_LIBRARY_PATH: str | None = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: str | None = None
SGL_DIFFUSION_CACHE_ROOT: str = os.path.expanduser("~/.cache/sgl_diffusion")
SGL_DIFFUSION_CONFIG_ROOT: str = os.path.expanduser("~/.config/sgl_diffusion")
SGL_DIFFUSION_CONFIGURE_LOGGING: int = 1
SGL_DIFFUSION_LOGGING_LEVEL: str = "INFO"
SGL_DIFFUSION_LOGGING_PREFIX: str = ""
SGL_DIFFUSION_LOGGING_CONFIG_PATH: str | None = None
SGL_DIFFUSION_TRACE_FUNCTION: int = 0
SGL_DIFFUSION_WORKER_MULTIPROC_METHOD: str = "fork"
SGL_DIFFUSION_TARGET_DEVICE: str = "cuda"
MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None
CMAKE_BUILD_TYPE: str | None = None
VERBOSE: bool = False
SGL_DIFFUSION_SERVER_DEV_MODE: bool = False
SGL_DIFFUSION_STAGE_LOGGING: bool = False
def _is_hip():
has_rocm = torch.version.hip is not None
return has_rocm
def _is_cuda():
has_cuda = torch.version.cuda is not None
return has_cuda
def _is_musa():
    try:
        if hasattr(torch, "musa") and torch.musa.is_available():
            return True
    except ModuleNotFoundError:
        pass
    return False
def _is_mps():
return torch.backends.mps.is_available()
class PackagesEnvChecker:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(PackagesEnvChecker, cls).__new__(cls)
cls._instance.initialize()
return cls._instance
def initialize(self):
self.packages_info = {
"has_aiter": self.check_aiter(),
"diffusers_version": self.check_diffusers_version(),
}
    def check_aiter(self):
        """
        Checks whether the ROCm AITER library is installed.
        """
        if importlib.util.find_spec("aiter") is not None:
            logger.info("Using AITER as the attention library")
            return True
        if _is_hip():
            logger.warning(
                'Using AMD GPUs, but library "aiter" is not installed, '
                "defaulting to other attention mechanisms"
            )
        return False
    def check_flash_attn(self):
        if not torch.cuda.is_available():
            return False
        if _is_musa():
            logger.info(
                "Flash Attention library is not supported on MUSA for the moment."
            )
            return False
        if importlib.util.find_spec("flash_attn") is not None:
            return True
        logger.warning(
            'Flash Attention library "flash_attn" not found, '
            "using pytorch attention implementation"
        )
        return False
def check_long_ctx_attn(self):
if not torch.cuda.is_available():
return False
try:
return importlib.util.find_spec("yunchang") is not None
except ImportError:
logger.warning(
f'Ring Flash Attention library "yunchang" not found, '
f"using pytorch attention implementation"
)
return False
def check_diffusers_version(self):
if version.parse(
version.parse(diffusers.__version__).base_version
) < version.parse("0.30.0"):
raise RuntimeError(
f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported,"
f"please upgrade to version > 0.30.0"
)
return version.parse(version.parse(diffusers.__version__).base_version)
def get_packages_info(self):
return self.packages_info
PACKAGES_CHECKER = PackagesEnvChecker()
def get_default_cache_root() -> str:
return os.getenv(
"XDG_CACHE_HOME",
os.path.join(os.path.expanduser("~"), ".cache"),
)
def get_default_config_root() -> str:
return os.getenv(
"XDG_CONFIG_HOME",
os.path.join(os.path.expanduser("~"), ".config"),
)
def maybe_convert_int(value: str | None) -> int | None:
if value is None:
return None
return int(value)
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# begin-env-vars-definition
environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of sgl-diffusion, supporting [cuda (by default),
# rocm, neuron, cpu, openvino]
"SGL_DIFFUSION_TARGET_DEVICE": lambda: os.getenv(
"SGL_DIFFUSION_TARGET_DEVICE", "cuda"
),
# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
# Number of threads to use for nvcc
# By default this is 1.
# If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU.
"NVCC_THREADS": lambda: os.getenv("NVCC_THREADS", None),
# If set, sgl_diffusion will use precompiled binaries (*.so)
"SGL_DIFFUSION_USE_PRECOMPILED": lambda: bool(
os.environ.get("SGL_DIFFUSION_USE_PRECOMPILED")
)
or bool(os.environ.get("SGL_DIFFUSION_PRECOMPILED_WHEEL_LOCATION")),
# CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo"
"CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
# If set, sgl_diffusion will print verbose logs during installation
"VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))),
    # Root directory for sgl_diffusion configuration files
# Defaults to `~/.config/sgl_diffusion` unless `XDG_CONFIG_HOME` is set
# Note that this not only affects how sgl_diffusion finds its configuration files
# during runtime, but also affects how sgl_diffusion installs its configuration
# files during **installation**.
"SGL_DIFFUSION_CONFIG_ROOT": lambda: os.path.expanduser(
os.getenv(
"SGL_DIFFUSION_CONFIG_ROOT",
os.path.join(get_default_config_root(), "sgl_diffusion"),
)
),
# ================== Runtime Env Vars ==================
    # Root directory for sgl_diffusion cache files
# Defaults to `~/.cache/sgl_diffusion` unless `XDG_CACHE_HOME` is set
"SGL_DIFFUSION_CACHE_ROOT": lambda: os.path.expanduser(
os.getenv(
"SGL_DIFFUSION_CACHE_ROOT",
os.path.join(get_default_cache_root(), "sgl_diffusion"),
)
),
# Interval in seconds to log a warning message when the ring buffer is full
"SGL_DIFFUSION_RINGBUFFER_WARNING_INTERVAL": lambda: int(
os.environ.get("SGL_DIFFUSION_RINGBUFFER_WARNING_INTERVAL", "60")
),
# Path to the NCCL library file. It is needed because nccl>=2.19 brought
# by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234
"SGL_DIFFUSION_NCCL_SO_PATH": lambda: os.environ.get(
"SGL_DIFFUSION_NCCL_SO_PATH", None
),
# when `SGL_DIFFUSION_NCCL_SO_PATH` is not set, sgl_diffusion will try to find the nccl
# library file in the locations specified by `LD_LIBRARY_PATH`
"LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None),
# Internal flag to enable Dynamo fullgraph capture
"SGL_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool(
os.environ.get("SGL_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"
),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
# used to control the visible devices in the distributed setting
"CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None),
# timeout for each iteration in the engine
"SGL_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S": lambda: int(
os.environ.get("SGL_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S", "60")
),
# Logging configuration
# If set to 0, sgl_diffusion will not configure logging
# If set to 1, sgl_diffusion will configure logging using the default configuration
# or the configuration file specified by SGL_DIFFUSION_LOGGING_CONFIG_PATH
"SGL_DIFFUSION_CONFIGURE_LOGGING": lambda: int(
os.getenv("SGL_DIFFUSION_CONFIGURE_LOGGING", "1")
),
"SGL_DIFFUSION_LOGGING_CONFIG_PATH": lambda: os.getenv(
"SGL_DIFFUSION_LOGGING_CONFIG_PATH"
),
# this is used for configuring the default logging level
"SGL_DIFFUSION_LOGGING_LEVEL": lambda: os.getenv(
"SGL_DIFFUSION_LOGGING_LEVEL", "INFO"
),
# if set, SGL_DIFFUSION_LOGGING_PREFIX will be prepended to all log messages
"SGL_DIFFUSION_LOGGING_PREFIX": lambda: os.getenv(
"SGL_DIFFUSION_LOGGING_PREFIX", ""
),
# Trace function calls
# If set to 1, sgl_diffusion will trace function calls
# Useful for debugging
"SGL_DIFFUSION_TRACE_FUNCTION": lambda: int(
os.getenv("SGL_DIFFUSION_TRACE_FUNCTION", "0")
),
# Path to the attention configuration file. Only used for sliding tile
# attention for now.
"SGL_DIFFUSION_ATTENTION_CONFIG": lambda: (
None
if os.getenv("SGL_DIFFUSION_ATTENTION_CONFIG", None) is None
else os.path.expanduser(os.getenv("SGL_DIFFUSION_ATTENTION_CONFIG", "."))
),
# Use dedicated multiprocess context for workers.
# Both spawn and fork work
"SGL_DIFFUSION_WORKER_MULTIPROC_METHOD": lambda: os.getenv(
"SGL_DIFFUSION_WORKER_MULTIPROC_METHOD", "fork"
),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
"SGL_DIFFUSION_TORCH_PROFILER_DIR": lambda: (
None
if os.getenv("SGL_DIFFUSION_TORCH_PROFILER_DIR", None) is None
else os.path.expanduser(os.getenv("SGL_DIFFUSION_TORCH_PROFILER_DIR", "."))
),
# If set, sgl_diffusion will run in development mode, which will enable
# some additional endpoints for developing and debugging,
# e.g. `/reset_prefix_cache`
"SGL_DIFFUSION_SERVER_DEV_MODE": lambda: bool(
int(os.getenv("SGL_DIFFUSION_SERVER_DEV_MODE", "0"))
),
# If set, sgl_diffusion will enable stage logging, which will print the time
# taken for each stage
"SGL_DIFFUSION_STAGE_LOGGING": lambda: bool(
int(os.getenv("SGL_DIFFUSION_STAGE_LOGGING", "0"))
),
}
# end-env-vars-definition
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in environment_variables:
return environment_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(environment_variables.keys())
def get_torch_distributed_backend() -> str:
if torch.cuda.is_available():
return "nccl"
elif _is_musa():
return "mccl"
elif _is_mps():
return "gloo"
else:
raise NotImplementedError(
"No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available"
)
def get_device(local_rank: int) -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda", local_rank)
elif _is_musa():
return torch.device("musa", local_rank)
elif _is_mps():
return torch.device("mps")
else:
return torch.device("cpu")
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Basic inference pipelines for sglang.multimodal_gen.
This package contains basic pipelines for video and image generation.
"""
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Flux diffusion pipeline implementation.
This module contains an implementation of the Flux image generation pipeline
using the modular pipeline architecture.
"""
from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, Req
from sglang.multimodal_gen.runtime.pipelines.stages import (
ConditioningStage,
DecodingStage,
DenoisingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
# TODO(will): move PRECISION_TO_TYPE to better place
logger = init_logger(__name__)
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
def prepare_mu(batch: Req, server_args: ServerArgs):
height = batch.height
width = batch.width
vae_scale_factor = (
server_args.pipeline_config.vae_config.arch_config.vae_scale_factor
)
image_seq_len = (int(height) // vae_scale_factor) * (int(width) // vae_scale_factor)
mu = calculate_shift(
image_seq_len,
# hard code, since scheduler_config is not in PipelineConfig now
256,
4096,
0.5,
1.15,
)
return "mu", mu
class FluxPipeline(ComposedPipelineBase):
pipeline_name = "FluxPipeline"
_required_config_modules = [
"text_encoder",
"text_encoder_2",
"tokenizer",
"tokenizer_2",
"vae",
"transformer",
"scheduler",
]
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage_primary",
stage=TextEncodingStage(
text_encoders=[
self.get_module("text_encoder"),
self.get_module("text_encoder_2"),
],
tokenizers=[
self.get_module("tokenizer"),
self.get_module("tokenizer_2"),
],
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(
scheduler=self.get_module("scheduler"),
prepare_extra_set_timesteps_kwargs=[prepare_mu],
),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer"),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
EntryClass = FluxPipeline
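# --- Illustration (editor's sketch, not the real ComposedPipelineBase API) ---
# The pipelines in this package are thin compositions: create_pipeline_stages
# registers named stages in order, and execution threads the request batch
# through them one stage at a time. The toy stage objects below are made up
# purely to show that pattern; they are not the actual stage interface.
if __name__ == "__main__":
    class _ToyStage:
        def __init__(self, name: str):
            self.name = name

        def forward(self, batch: list) -> list:
            batch.append(self.name)
            return batch

    toy_stages = [
        _ToyStage(name)
        for name in ("input_validation", "prompt_encoding", "denoising", "decoding")
    ]
    toy_batch: list = []
    for toy_stage in toy_stages:
        toy_batch = toy_stage.forward(toy_batch)
    print(toy_batch)  # ['input_validation', 'prompt_encoding', 'denoising', 'decoding']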
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Hunyuan video diffusion pipeline implementation.
This module contains an implementation of the Hunyuan video diffusion pipeline
using the modular pipeline architecture.
"""
from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase
from sglang.multimodal_gen.runtime.pipelines.stages import (
ConditioningStage,
DecodingStage,
DenoisingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
# TODO(will): move PRECISION_TO_TYPE to a better place
logger = init_logger(__name__)
class HunyuanVideoPipeline(ComposedPipelineBase):
pipeline_name = "HunyuanVideoPipeline"
_required_config_modules = [
"text_encoder",
"text_encoder_2",
"tokenizer",
"tokenizer_2",
"vae",
"transformer",
"scheduler",
]
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage_primary",
stage=TextEncodingStage(
text_encoders=[
self.get_module("text_encoder"),
self.get_module("text_encoder_2"),
],
tokenizers=[
self.get_module("tokenizer"),
self.get_module("tokenizer_2"),
],
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer"),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
EntryClass = HunyuanVideoPipeline
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Hunyuan video diffusion pipeline implementation.
This module contains an implementation of the Hunyuan video diffusion pipeline
using the modular pipeline architecture.
"""
from diffusers.image_processor import VaeImageProcessor
from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, Req
from sglang.multimodal_gen.runtime.pipelines.stages import (
ConditioningStage,
DecodingStage,
DenoisingStage,
ImageEncodingStage,
ImageVAEEncodingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
# TODO(will): move PRECISION_TO_TYPE to a better place
logger = init_logger(__name__)
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
def prepare_mu(batch: Req, server_args: ServerArgs):
height = batch.height
width = batch.width
vae_scale_factor = server_args.pipeline_config.vae_config.vae_scale_factor
image_seq_len = (int(height) // vae_scale_factor) * (int(width) // vae_scale_factor)
mu = calculate_shift(
image_seq_len,
# hard code, since scheduler_config is not in PipelineConfig now
256,
4096,
0.5,
1.15,
)
return "mu", mu
class QwenImagePipeline(ComposedPipelineBase):
pipeline_name = "QwenImagePipeline"
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
]
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage_primary",
stage=TextEncodingStage(
text_encoders=[
self.get_module("text_encoder"),
],
tokenizers=[
self.get_module("tokenizer"),
],
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(
scheduler=self.get_module("scheduler"),
prepare_extra_set_timesteps_kwargs=[prepare_mu],
),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer"),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
class QwenImageEditPipeline(ComposedPipelineBase):
pipeline_name = "QwenImageEditPipeline"
_required_config_modules = [
"processor",
"scheduler",
"text_encoder",
"tokenizer",
"transformer",
"vae",
]
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage_primary",
stage=ImageEncodingStage(
image_processor=self.get_module("processor"),
text_encoder=self.get_module("text_encoder"),
vae_image_processor=VaeImageProcessor(
vae_scale_factor=server_args.pipeline_config.vae_config.arch_config.vae_scale_factor
* 2
),
),
)
self.add_stage(
stage_name="image_encoding_stage_primary",
stage=ImageVAEEncodingStage(
vae_image_processor=VaeImageProcessor(
vae_scale_factor=server_args.pipeline_config.vae_config.arch_config.vae_scale_factor
* 2
),
vae=self.get_module("vae"),
),
)
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(
scheduler=self.get_module("scheduler"),
prepare_extra_set_timesteps_kwargs=[prepare_mu],
),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer"),
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
EntryClass = [QwenImagePipeline, QwenImageEditPipeline]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# type: ignore
"""
StepVideo diffusion pipeline implementation.
This module contains an implementation of the StepVideo video generation
pipeline using the modular pipeline architecture.
"""
import os
from typing import Any
import torch
from huggingface_hub import hf_hub_download
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.loader.component_loader import (
PipelineComponentLoader,
)
from sglang.multimodal_gen.runtime.models.encoders.bert import (
HunyuanClip, # type: ignore
)
from sglang.multimodal_gen.runtime.models.encoders.stepllm import STEP1TextEncoder
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline
from sglang.multimodal_gen.runtime.pipelines.stages import (
DecodingStage,
DenoisingStage,
InputValidationStage,
LatentPreparationStage,
StepvideoPromptEncodingStage,
TimestepPreparationStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class StepVideoPipeline(LoRAPipeline, ComposedPipelineBase):
pipeline_name = "StepVideoPipeline"
_required_config_modules = ["transformer", "scheduler", "vae"]
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=StepvideoPromptEncodingStage(
stepllm=self.get_module("text_encoder"),
clip=self.get_module("text_encoder_2"),
),
)
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer"),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
def build_llm(self, model_dir, device) -> torch.nn.Module:
text_encoder = (
STEP1TextEncoder(model_dir, max_length=320).to(torch.bfloat16).eval()
)
return text_encoder
def build_clip(self, model_dir, device) -> HunyuanClip:
clip = HunyuanClip(model_dir, max_length=77).eval()
return clip
def initialize_pipeline(self, server_args: ServerArgs):
"""
Initialize the pipeline.
"""
target_device = get_local_torch_device()
llm_dir = os.path.join(self.model_path, "step_llm")
clip_dir = os.path.join(self.model_path, "hunyuan_clip")
text_enc = self.build_llm(llm_dir, target_device)
clip_enc = self.build_clip(clip_dir, target_device)
self.add_module("text_encoder", text_enc)
self.add_module("text_encoder_2", clip_enc)
lib_path = (
os.path.join(
server_args.model_path,
"lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so",
)
if os.path.isdir(server_args.model_path) # local checkout
else hf_hub_download(
repo_id=server_args.model_path,
filename="lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so",
)
)
torch.ops.load_library(lib_path)
def load_modules(
self,
server_args: ServerArgs,
loaded_modules: dict[str, torch.nn.Module] | None = None,
) -> dict[str, Any]:
"""
Load the modules from the config.
"""
model_index = self._load_config()
logger.info("Loading pipeline modules from config: %s", model_index)
# remove keys that are not pipeline modules
model_index.pop("_class_name")
model_index.pop("_diffusers_version")
# some sanity checks
assert (
len(model_index) > 1
), "model_index.json must contain at least one pipeline module"
required_modules = ["transformer", "scheduler", "vae"]
for module_name in required_modules:
if module_name not in model_index:
raise ValueError(
f"model_index.json must contain a {module_name} module"
)
logger.info("Diffusers config passed sanity checks")
# all the component models used by the pipeline
modules = {}
for module_name, (
transformers_or_diffusers,
architecture,
) in model_index.items():
component_model_path = os.path.join(self.model_path, module_name)
module = PipelineComponentLoader.load_module(
module_name=module_name,
component_model_path=component_model_path,
transformers_or_diffusers=transformers_or_diffusers,
server_args=server_args,
)
logger.info("Loaded module %s from %s", module_name, component_model_path)
if module_name in modules:
logger.warning("Overwriting module %s", module_name)
modules[module_name] = module
required_modules = self.required_config_modules
# Check if all required modules were loaded
for module_name in required_modules:
if module_name not in modules or modules[module_name] is None:
raise ValueError(
f"Required module {module_name} was not loaded properly"
)
return modules
EntryClass = StepVideoPipeline
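# --- Illustration (editor's sketch, not part of the original module) ---
# load_modules above consumes a diffusers-style model_index.json: after the
# "_class_name" / "_diffusers_version" bookkeeping keys are popped, every
# remaining entry maps a module name to a (library, architecture) pair. The
# concrete names below are illustrative, not the real StepVideo configuration.
if __name__ == "__main__":
    example_model_index = {
        "_class_name": "StepVideoPipeline",
        "_diffusers_version": "0.30.0",
        "transformer": ["diffusers", "StepVideoTransformer"],
        "scheduler": ["diffusers", "FlowMatchDiscreteScheduler"],
        "vae": ["diffusers", "AutoencoderKL"],
    }
    example_model_index.pop("_class_name")
    example_model_index.pop("_diffusers_version")
    # Mirrors the unpacking loop in load_modules().
    for module_name, (library, architecture) in example_model_index.items():
        print(module_name, library, architecture)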
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Wan causal DMD pipeline implementation.
This module wires the causal DMD denoising stage into the modular pipeline.
"""
from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, LoRAPipeline
# isort: off
from sglang.multimodal_gen.runtime.pipelines.stages import (
ConditioningStage,
DecodingStage,
CausalDMDDenoisingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
# isort: on
logger = init_logger(__name__)
class WanCausalDMDPipeline(LoRAPipeline, ComposedPipelineBase):
pipeline_name = "WanPipeline"
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
]
def create_pipeline_stages(self, server_args: ServerArgs) -> None:
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer", None),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=CausalDMDDenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
EntryClass = WanCausalDMDPipeline
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Wan video diffusion pipeline implementation.
This module contains an implementation of the Wan video diffusion pipeline
using the modular pipeline architecture.
"""
from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import (
FlowMatchEulerDiscreteScheduler,
)
from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, LoRAPipeline
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
# isort: off
from sglang.multimodal_gen.runtime.pipelines.stages import (
ConditioningStage,
DecodingStage,
DmdDenoisingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
# isort: on
logger = init_logger(__name__)
class WanDMDPipeline(LoRAPipeline, ComposedPipelineBase):
"""
Wan video diffusion pipeline with LoRA support.
"""
pipeline_name = "WanDMDPipeline"
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
]
def initialize_pipeline(self, server_args: ServerArgs):
self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
shift=server_args.pipeline_config.flow_shift
)
def create_pipeline_stages(self, server_args: ServerArgs) -> None:
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer", None),
),
)
self.add_stage(
stage_name="denoising_stage",
stage=DmdDenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
EntryClass = WanDMDPipeline
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Wan video diffusion pipeline implementation.
This module contains an implementation of the Wan video diffusion pipeline
using the modular pipeline architecture.
"""
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
# isort: off
from sglang.multimodal_gen.runtime.pipelines.stages import (
ImageEncodingStage,
ConditioningStage,
DecodingStage,
DmdDenoisingStage,
ImageVAEEncodingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
# isort: on
from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import (
FlowMatchEulerDiscreteScheduler,
)
logger = init_logger(__name__)
class WanImageToVideoDmdPipeline(LoRAPipeline, ComposedPipelineBase):
pipeline_name = "WanCausalDMDPipeline"
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
"image_encoder",
"image_processor",
]
def initialize_pipeline(self, server_args: ServerArgs):
self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
shift=server_args.pipeline_config.flow_shift
)
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
self.add_stage(
stage_name="image_encoding_stage",
stage=ImageEncodingStage(
image_encoder=self.get_module("image_encoder"),
image_processor=self.get_module("image_processor"),
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer"),
),
)
self.add_stage(
stage_name="image_latent_preparation_stage",
stage=ImageVAEEncodingStage(vae=self.get_module("vae")),
)
self.add_stage(
stage_name="denoising_stage",
stage=DmdDenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
EntryClass = WanImageToVideoDmdPipeline
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Wan video diffusion pipeline implementation.
This module contains an implementation of the Wan video diffusion pipeline
using the modular pipeline architecture.
"""
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
# isort: off
from sglang.multimodal_gen.runtime.pipelines.stages import (
ImageEncodingStage,
ConditioningStage,
DecodingStage,
DenoisingStage,
ImageVAEEncodingStage,
InputValidationStage,
LatentPreparationStage,
TextEncodingStage,
TimestepPreparationStage,
)
# isort: on
from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_unipc_multistep import (
FlowUniPCMultistepScheduler,
)
logger = init_logger(__name__)
class WanImageToVideoPipeline(LoRAPipeline, ComposedPipelineBase):
pipeline_name = "WanImageToVideoPipeline"
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
"image_encoder",
"image_processor",
]
def initialize_pipeline(self, server_args: ServerArgs):
self.modules["scheduler"] = FlowUniPCMultistepScheduler(
shift=server_args.pipeline_config.flow_shift
)
def create_pipeline_stages(self, server_args: ServerArgs):
"""Set up pipeline stages with proper dependency injection."""
self.add_stage(
stage_name="input_validation_stage", stage=InputValidationStage()
)
self.add_stage(
stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
),
)
if (
self.get_module("image_encoder") is not None
and self.get_module("image_processor") is not None
):
self.add_stage(
stage_name="image_encoding_stage",
stage=ImageEncodingStage(
image_encoder=self.get_module("image_encoder"),
image_processor=self.get_module("image_processor"),
),
)
self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
self.add_stage(
stage_name="timestep_preparation_stage",
stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")),
)
self.add_stage(
stage_name="latent_preparation_stage",
stage=LatentPreparationStage(
scheduler=self.get_module("scheduler"),
transformer=self.get_module("transformer"),
),
)
self.add_stage(
stage_name="image_latent_preparation_stage",
stage=ImageVAEEncodingStage(vae=self.get_module("vae")),
)
self.add_stage(
stage_name="denoising_stage",
stage=DenoisingStage(
transformer=self.get_module("transformer"),
transformer_2=self.get_module("transformer_2"),
scheduler=self.get_module("scheduler"),
),
)
self.add_stage(
stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"))
)
EntryClass = WanImageToVideoPipeline