Unverified commit 7bc1dae0, authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)

Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>

Parent: 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
import torch
from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig
def is_double_block(n: str, m) -> bool:
return "double" in n and str.isdigit(n.split(".")[-1])
def is_single_block(n: str, m) -> bool:
return "single" in n and str.isdigit(n.split(".")[-1])
def is_refiner_block(n: str, m) -> bool:
return "refiner" in n and str.isdigit(n.split(".")[-1])
def is_txt_in(n: str, m) -> bool:
return n.split(".")[-1] == "txt_in"
@dataclass
class HunyuanVideoArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(
default_factory=lambda: [is_double_block, is_single_block, is_refiner_block]
)
_compile_conditions: list = field(
default_factory=lambda: [is_double_block, is_single_block, is_txt_in]
)
param_names_mapping: dict = field(
default_factory=lambda: {
# 1. context_embedder.time_text_embed submodules (specific rules, applied first):
r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_1\.(.*)$": r"txt_in.t_embedder.mlp.fc_in.\1",
r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_2\.(.*)$": r"txt_in.t_embedder.mlp.fc_out.\1",
r"^context_embedder\.proj_in\.(.*)$": r"txt_in.input_embedder.\1",
r"^context_embedder\.time_text_embed\.text_embedder\.linear_1\.(.*)$": r"txt_in.c_embedder.fc_in.\1",
r"^context_embedder\.time_text_embed\.text_embedder\.linear_2\.(.*)$": r"txt_in.c_embedder.fc_out.\1",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm1\.(.*)$": r"txt_in.refiner_blocks.\1.norm1.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm2\.(.*)$": r"txt_in.refiner_blocks.\1.norm2.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_q\.(.*)$": (
r"txt_in.refiner_blocks.\1.self_attn_qkv.\2",
0,
3,
),
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_k\.(.*)$": (
r"txt_in.refiner_blocks.\1.self_attn_qkv.\2",
1,
3,
),
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_v\.(.*)$": (
r"txt_in.refiner_blocks.\1.self_attn_qkv.\2",
2,
3,
),
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$": r"txt_in.refiner_blocks.\1.self_attn_proj.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$": r"txt_in.refiner_blocks.\1.mlp.fc_in.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$": r"txt_in.refiner_blocks.\1.mlp.fc_out.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm_out\.linear\.(.*)$": r"txt_in.refiner_blocks.\1.adaLN_modulation.linear.\2",
# 3. x_embedder mapping:
r"^x_embedder\.proj\.(.*)$": r"img_in.proj.\1",
# 4. Top-level time_text_embed mappings:
r"^time_text_embed\.timestep_embedder\.linear_1\.(.*)$": r"time_in.mlp.fc_in.\1",
r"^time_text_embed\.timestep_embedder\.linear_2\.(.*)$": r"time_in.mlp.fc_out.\1",
r"^time_text_embed\.guidance_embedder\.linear_1\.(.*)$": r"guidance_in.mlp.fc_in.\1",
r"^time_text_embed\.guidance_embedder\.linear_2\.(.*)$": r"guidance_in.mlp.fc_out.\1",
r"^time_text_embed\.text_embedder\.linear_1\.(.*)$": r"vector_in.fc_in.\1",
r"^time_text_embed\.text_embedder\.linear_2\.(.*)$": r"vector_in.fc_out.\1",
# 5. transformer_blocks mapping:
r"^transformer_blocks\.(\d+)\.norm1\.linear\.(.*)$": r"double_blocks.\1.img_mod.linear.\2",
r"^transformer_blocks\.(\d+)\.norm1_context\.linear\.(.*)$": r"double_blocks.\1.txt_mod.linear.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$": r"double_blocks.\1.img_attn_q_norm.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$": r"double_blocks.\1.img_attn_k_norm.\2",
r"^transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$": (
r"double_blocks.\1.img_attn_qkv.\2",
0,
3,
),
r"^transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$": (
r"double_blocks.\1.img_attn_qkv.\2",
1,
3,
),
r"^transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$": (
r"double_blocks.\1.img_attn_qkv.\2",
2,
3,
),
r"^transformer_blocks\.(\d+)\.attn\.add_q_proj\.(.*)$": (
r"double_blocks.\1.txt_attn_qkv.\2",
0,
3,
),
r"^transformer_blocks\.(\d+)\.attn\.add_k_proj\.(.*)$": (
r"double_blocks.\1.txt_attn_qkv.\2",
1,
3,
),
r"^transformer_blocks\.(\d+)\.attn\.add_v_proj\.(.*)$": (
r"double_blocks.\1.txt_attn_qkv.\2",
2,
3,
),
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$": r"double_blocks.\1.img_attn_proj.\2",
# Map attn.to_add_out to the text-branch output projection.
r"^transformer_blocks\.(\d+)\.attn\.to_add_out\.(.*)$": r"double_blocks.\1.txt_attn_proj.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_added_q\.(.*)$": r"double_blocks.\1.txt_attn_q_norm.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_added_k\.(.*)$": r"double_blocks.\1.txt_attn_k_norm.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$": r"double_blocks.\1.img_mlp.fc_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$": r"double_blocks.\1.img_mlp.fc_out.\2",
r"^transformer_blocks\.(\d+)\.ff_context\.net\.0(?:\.proj)?\.(.*)$": r"double_blocks.\1.txt_mlp.fc_in.\2",
r"^transformer_blocks\.(\d+)\.ff_context\.net\.2(?:\.proj)?\.(.*)$": r"double_blocks.\1.txt_mlp.fc_out.\2",
# 6. single_transformer_blocks mapping:
r"^single_transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$": r"single_blocks.\1.q_norm.\2",
r"^single_transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$": r"single_blocks.\1.k_norm.\2",
r"^single_transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$": (
r"single_blocks.\1.linear1.\2",
0,
4,
),
r"^single_transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$": (
r"single_blocks.\1.linear1.\2",
1,
4,
),
r"^single_transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$": (
r"single_blocks.\1.linear1.\2",
2,
4,
),
r"^single_transformer_blocks\.(\d+)\.proj_mlp\.(.*)$": (
r"single_blocks.\1.linear1.\2",
3,
4,
),
# Map proj_out to linear2; norm.linear maps to modulation.linear below.
r"^single_transformer_blocks\.(\d+)\.proj_out\.(.*)$": r"single_blocks.\1.linear2.\2",
r"^single_transformer_blocks\.(\d+)\.norm\.linear\.(.*)$": r"single_blocks.\1.modulation.linear.\2",
# 7. Final layers mapping:
r"^norm_out\.linear\.(.*)$": r"final_layer.adaLN_modulation.linear.\1",
r"^proj_out\.(.*)$": r"final_layer.linear.\1",
}
)
# Reverse mapping for saving checkpoints: custom -> hf
reverse_param_names_mapping: dict = field(default_factory=lambda: {})
patch_size: int = 2
patch_size_t: int = 1
in_channels: int = 16
out_channels: int = 16
num_attention_heads: int = 24
attention_head_dim: int = 128
mlp_ratio: float = 4.0
num_layers: int = 20
num_single_layers: int = 40
num_refiner_layers: int = 2
rope_axes_dim: tuple[int, int, int] = (16, 56, 56)
guidance_embeds: bool = False
dtype: torch.dtype | None = None
text_embed_dim: int = 4096
pooled_projection_dim: int = 768
rope_theta: int = 256
qk_norm: str = "rms_norm"
exclude_lora_layers: list[str] = field(
default_factory=lambda: ["img_in", "txt_in", "time_in", "vector_in"]
)
def __post_init__(self):
super().__post_init__()
self.hidden_size: int = self.attention_head_dim * self.num_attention_heads
self.num_channels_latents: int = self.in_channels
@dataclass
class HunyuanVideoConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=HunyuanVideoArchConfig)
prefix: str = "Hunyuan"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Tuple
from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig
@dataclass
class QwenImageArchConfig(DiTArchConfig):
patch_size: int = 1
in_channels: int = 64
out_channels: int | None = None
num_layers: int = 19
num_single_layers: int = 38
attention_head_dim: int = 128
num_attention_heads: int = 24
joint_attention_dim: int = 4096
pooled_projection_dim: int = 768
guidance_embeds: bool = False
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
def __post_init__(self):
super().__post_init__()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.out_channels
@dataclass
class QwenImageDitConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=QwenImageArchConfig)
prefix: str = "qwenimage"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig
def is_transformer_blocks(n, m):
return "transformer_blocks" in n and n.split(".")[-1].isdigit()
@dataclass
class StepVideoArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(
default_factory=lambda: [is_transformer_blocks]
)
param_names_mapping: dict = field(
default_factory=lambda: {
# transformer block
r"^transformer_blocks\.(\d+)\.norm1\.(weight|bias)$": r"transformer_blocks.\1.norm1.norm.\2",
r"^transformer_blocks\.(\d+)\.norm2\.(weight|bias)$": r"transformer_blocks.\1.norm2.norm.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.weight$": r"transformer_blocks.\1.ff.fc_in.weight",
r"^transformer_blocks\.(\d+)\.ff\.net\.2\.weight$": r"transformer_blocks.\1.ff.fc_out.weight",
# adanorm block
r"^adaln_single\.emb\.timestep_embedder\.linear_1\.(weight|bias)$": r"adaln_single.emb.mlp.fc_in.\1",
r"^adaln_single\.emb\.timestep_embedder\.linear_2\.(weight|bias)$": r"adaln_single.emb.mlp.fc_out.\1",
# caption projection
r"^caption_projection\.linear_1\.(weight|bias)$": r"caption_projection.fc_in.\1",
r"^caption_projection\.linear_2\.(weight|bias)$": r"caption_projection.fc_out.\1",
}
)
num_attention_heads: int = 48
attention_head_dim: int = 128
in_channels: int = 64
out_channels: int | None = 64
num_layers: int = 48
dropout: float = 0.0
patch_size: int = 1
norm_type: str = "ada_norm_single"
norm_elementwise_affine: bool = False
norm_eps: float = 1e-6
caption_channels: int | list[int] | tuple[int, ...] | None = field(
default_factory=lambda: [6144, 1024]
)
attention_type: str | None = "torch"
use_additional_conditions: bool | None = False
exclude_lora_layers: list[str] = field(default_factory=lambda: [])
def __post_init__(self):
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.out_channels = (
self.in_channels if self.out_channels is None else self.out_channels
)
self.num_channels_latents = self.out_channels
@dataclass
class StepVideoConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=StepVideoArchConfig)
prefix: str = "StepVideo"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig
def is_blocks(n: str, m) -> bool:
return "blocks" in n and str.isdigit(n.split(".")[-1])
@dataclass
class WanVideoArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])
param_names_mapping: dict = field(
default_factory=lambda: {
r"^patch_embedding\.(.*)$": r"patch_embedding.proj.\1",
r"^condition_embedder\.text_embedder\.linear_1\.(.*)$": r"condition_embedder.text_embedder.fc_in.\1",
r"^condition_embedder\.text_embedder\.linear_2\.(.*)$": r"condition_embedder.text_embedder.fc_out.\1",
r"^condition_embedder\.time_embedder\.linear_1\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_in.\1",
r"^condition_embedder\.time_embedder\.linear_2\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_out.\1",
r"^condition_embedder\.time_proj\.(.*)$": r"condition_embedder.time_modulation.linear.\1",
r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.(.*)$": r"condition_embedder.image_embedder.ff.fc_in.\1",
r"^condition_embedder\.image_embedder\.ff\.net\.2\.(.*)$": r"condition_embedder.image_embedder.ff.fc_out.\1",
r"^blocks\.(\d+)\.attn1\.to_q\.(.*)$": r"blocks.\1.to_q.\2",
r"^blocks\.(\d+)\.attn1\.to_k\.(.*)$": r"blocks.\1.to_k.\2",
r"^blocks\.(\d+)\.attn1\.to_v\.(.*)$": r"blocks.\1.to_v.\2",
r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$": r"blocks.\1.to_out.\2",
r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$": r"blocks.\1.norm_q.\2",
r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$": r"blocks.\1.norm_k.\2",
r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": r"blocks.\1.attn2.to_out.\2",
r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$": r"blocks.\1.ffn.fc_in.\2",
r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2",
r"^blocks\.(\d+)\.norm2\.(.*)$": r"blocks.\1.self_attn_residual_norm.norm.\2",
}
)
# Reverse mapping for saving checkpoints: custom -> hf
reverse_param_names_mapping: dict = field(default_factory=lambda: {})
# Some LoRA adapters use the original official layer names instead of HF layer names,
# so this mapping is applied before param_names_mapping.
lora_param_names_mapping: dict = field(
default_factory=lambda: {
r"^blocks\.(\d+)\.self_attn\.q\.(.*)$": r"blocks.\1.attn1.to_q.\2",
r"^blocks\.(\d+)\.self_attn\.k\.(.*)$": r"blocks.\1.attn1.to_k.\2",
r"^blocks\.(\d+)\.self_attn\.v\.(.*)$": r"blocks.\1.attn1.to_v.\2",
r"^blocks\.(\d+)\.self_attn\.o\.(.*)$": r"blocks.\1.attn1.to_out.0.\2",
r"^blocks\.(\d+)\.cross_attn\.q\.(.*)$": r"blocks.\1.attn2.to_q.\2",
r"^blocks\.(\d+)\.cross_attn\.k\.(.*)$": r"blocks.\1.attn2.to_k.\2",
r"^blocks\.(\d+)\.cross_attn\.v\.(.*)$": r"blocks.\1.attn2.to_v.\2",
r"^blocks\.(\d+)\.cross_attn\.o\.(.*)$": r"blocks.\1.attn2.to_out.0.\2",
r"^blocks\.(\d+)\.ffn\.0\.(.*)$": r"blocks.\1.ffn.fc_in.\2",
r"^blocks\.(\d+)\.ffn\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2",
}
)
patch_size: tuple[int, int, int] = (1, 2, 2)
text_len: int = 512
num_attention_heads: int = 40
attention_head_dim: int = 128
in_channels: int = 16
out_channels: int = 16
text_dim: int = 4096
freq_dim: int = 256
ffn_dim: int = 13824
num_layers: int = 40
cross_attn_norm: bool = True
qk_norm: str = "rms_norm_across_heads"
eps: float = 1e-6
image_dim: int | None = None
added_kv_proj_dim: int | None = None
rope_max_seq_len: int = 1024
pos_embed_seq_len: int | None = None
exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])
# Wan MoE
boundary_ratio: float | None = None
# Causal Wan
# Window size for temporal local attention (-1 indicates global attention)
local_attn_size: int = -1
# Size of the attention sink; the first `sink_size` frames are kept unchanged when rolling the KV cache
sink_size: int = 0
num_frames_per_block: int = 3
sliding_window_num_frames: int = 21
def __post_init__(self):
super().__post_init__()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.out_channels
@dataclass
class WanVideoConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=WanVideoArchConfig)
prefix: str = "Wan"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.configs.models.encoders.base import (
BaseEncoderOutput,
EncoderConfig,
ImageEncoderConfig,
TextEncoderConfig,
)
from sglang.multimodal_gen.configs.models.encoders.clip import (
CLIPTextConfig,
CLIPVisionConfig,
)
from sglang.multimodal_gen.configs.models.encoders.llama import LlamaConfig
from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config
__all__ = [
"EncoderConfig",
"TextEncoderConfig",
"ImageEncoderConfig",
"BaseEncoderOutput",
"CLIPTextConfig",
"CLIPVisionConfig",
"LlamaConfig",
"T5Config",
]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Any
import torch
from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig
from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
@dataclass
class EncoderArchConfig(ArchConfig):
architectures: list[str] = field(default_factory=lambda: [])
_supported_attention_backends: set[AttentionBackendEnum] = field(
default_factory=lambda: {
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
}
)
output_hidden_states: bool = False
use_return_dict: bool = True
@dataclass
class TextEncoderArchConfig(EncoderArchConfig):
vocab_size: int = 0
hidden_size: int = 0
num_hidden_layers: int = 0
num_attention_heads: int = 0
pad_token_id: int = 0
eos_token_id: int = 0
text_len: int = 0
hidden_state_skip_layer: int = 0
decoder_start_token_id: int = 0
output_past: bool = True
scalable_attention: bool = True
tie_word_embeddings: bool = False
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=list
) # mapping from huggingface weight names to custom names
tokenizer_kwargs: dict[str, Any] = field(default_factory=dict)
_fsdp_shard_conditions: list = field(default_factory=lambda: [])
def __post_init__(self) -> None:
self.tokenizer_kwargs = {
"truncation": True,
"max_length": self.text_len,
"return_tensors": "pt",
}
@dataclass
class ImageEncoderArchConfig(EncoderArchConfig):
pass
@dataclass
class BaseEncoderOutput:
last_hidden_state: torch.FloatTensor | None = None
pooler_output: torch.FloatTensor | None = None
hidden_states: tuple[torch.FloatTensor, ...] | None = None
attentions: tuple[torch.FloatTensor, ...] | None = None
attention_mask: torch.Tensor | None = None
@dataclass
class EncoderConfig(ModelConfig):
arch_config: ArchConfig = field(default_factory=EncoderArchConfig)
prefix: str = ""
quant_config: QuantizationConfig | None = None
lora_config: Any | None = None
@dataclass
class TextEncoderConfig(EncoderConfig):
arch_config: ArchConfig = field(default_factory=TextEncoderArchConfig)
@dataclass
class ImageEncoderConfig(EncoderConfig):
arch_config: ArchConfig = field(default_factory=ImageEncoderArchConfig)
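TextEncoderArchConfig.__post_init__ pre-builds tokenizer_kwargs (truncation to text_len, PyTorch tensors) so a text-encoding stage can forward them straight to a tokenizer. A hedged usage sketch assuming a HuggingFace tokenizer; the model name and the encode_prompt helper are illustrative, not part of this commit.

from transformers import AutoTokenizer
from sglang.multimodal_gen.configs.models.encoders.base import TextEncoderArchConfig

arch = TextEncoderArchConfig(text_len=77)  # __post_init__ fills tokenizer_kwargs
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # illustrative model

def encode_prompt(prompt: str):
    # Expands to truncation=True, max_length=77, return_tensors="pt".
    return tokenizer(prompt, **arch.tokenizer_kwargs)

print(encode_prompt("a corgi surfing at sunset")["input_ids"].shape)  # (1, <= 77)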
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.encoders.base import (
ImageEncoderArchConfig,
ImageEncoderConfig,
TextEncoderArchConfig,
TextEncoderConfig,
)
def _is_transformer_layer(n: str, m) -> bool:
return "layers" in n and str.isdigit(n.split(".")[-1])
def _is_embeddings(n: str, m) -> bool:
return n.endswith("embeddings")
@dataclass
class CLIPTextArchConfig(TextEncoderArchConfig):
vocab_size: int = 49408
hidden_size: int = 512
intermediate_size: int = 2048
projection_dim: int = 512
num_hidden_layers: int = 12
num_attention_heads: int = 8
max_position_embeddings: int = 77
hidden_act: str = "quick_gelu"
layer_norm_eps: float = 1e-5
dropout: float = 0.0
attention_dropout: float = 0.0
initializer_range: float = 0.02
initializer_factor: float = 1.0
pad_token_id: int = 1
bos_token_id: int = 49406
eos_token_id: int = 49407
text_len: int = 77
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
)
_fsdp_shard_conditions: list = field(
default_factory=lambda: [_is_transformer_layer, _is_embeddings]
)
@dataclass
class CLIPVisionArchConfig(ImageEncoderArchConfig):
hidden_size: int = 768
intermediate_size: int = 3072
projection_dim: int = 512
num_hidden_layers: int = 12
num_attention_heads: int = 12
num_channels: int = 3
image_size: int = 224
patch_size: int = 32
hidden_act: str = "quick_gelu"
layer_norm_eps: float = 1e-5
dropout: float = 0.0
attention_dropout: float = 0.0
initializer_range: float = 0.02
initializer_factor: float = 1.0
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
)
@dataclass
class CLIPTextConfig(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=CLIPTextArchConfig)
num_hidden_layers_override: int | None = None
require_post_norm: bool | None = None
prefix: str = "clip"
@dataclass
class CLIPVisionConfig(ImageEncoderConfig):
arch_config: ImageEncoderArchConfig = field(default_factory=CLIPVisionArchConfig)
num_hidden_layers_override: int | None = None
require_post_norm: bool | None = None
prefix: str = "clip"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.encoders.base import (
TextEncoderArchConfig,
TextEncoderConfig,
)
def _is_transformer_layer(n: str, m) -> bool:
return "layers" in n and str.isdigit(n.split(".")[-1])
def _is_embeddings(n: str, m) -> bool:
return n.endswith("embed_tokens")
def _is_final_norm(n: str, m) -> bool:
return n.endswith("norm")
@dataclass
class LlamaArchConfig(TextEncoderArchConfig):
vocab_size: int = 32000
hidden_size: int = 4096
intermediate_size: int = 11008
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int | None = None
hidden_act: str = "silu"
max_position_embeddings: int = 2048
initializer_range: float = 0.02
rms_norm_eps: float = 1e-6
use_cache: bool = True
pad_token_id: int = 0
bos_token_id: int = 1
eos_token_id: int = 2
pretraining_tp: int = 1
tie_word_embeddings: bool = False
rope_theta: float = 10000.0
rope_scaling: float | None = None
attention_bias: bool = False
attention_dropout: float = 0.0
mlp_bias: bool = False
head_dim: int | None = None
hidden_state_skip_layer: int = 2
text_len: int = 256
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0), # type: ignore
(".gate_up_proj", ".up_proj", 1), # type: ignore
]
)
_fsdp_shard_conditions: list = field(
default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm]
)
@dataclass
class LlamaConfig(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig)
prefix: str = "llama"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.encoders.base import (
TextEncoderArchConfig,
TextEncoderConfig,
)
def _is_transformer_layer(n: str, m) -> bool:
return "layers" in n and str.isdigit(n.split(".")[-1])
def _is_embeddings(n: str, m) -> bool:
return n.endswith("embed_tokens")
def _is_final_norm(n: str, m) -> bool:
return n.endswith("norm")
@dataclass
class QwenImageArchConfig(TextEncoderArchConfig):
vocab_size: int = 32000
hidden_size: int = 4096
intermediate_size: int = 11008
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int | None = None
hidden_act: str = "silu"
max_position_embeddings: int = 2048
initializer_range: float = 0.02
rms_norm_eps: float = 1e-6
use_cache: bool = True
pad_token_id: int = -1
eos_token_id: int = 2
pretraining_tp: int = 1
tie_word_embeddings: bool = False
rope_theta: float = 10000.0
rope_scaling: float | None = None
attention_bias: bool = False
attention_dropout: float = 0.0
mlp_bias: bool = False
head_dim: int | None = None
hidden_state_skip_layer: int = 2
text_len: int = 256
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0), # type: ignore
(".gate_up_proj", ".up_proj", 1), # type: ignore
]
)
_fsdp_shard_conditions: list = field(
default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm]
)
@dataclass
class Qwen2_5VLConfig(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=QwenImageArchConfig)
# prefix: str = "qwen_image"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.encoders.base import (
TextEncoderArchConfig,
TextEncoderConfig,
)
def _is_transformer_layer(n: str, m) -> bool:
return "block" in n and str.isdigit(n.split(".")[-1])
def _is_embeddings(n: str, m) -> bool:
return n.endswith("shared")
def _is_final_layernorm(n: str, m) -> bool:
return n.endswith("final_layer_norm")
@dataclass
class T5ArchConfig(TextEncoderArchConfig):
vocab_size: int = 32128
d_model: int = 512
d_kv: int = 64
d_ff: int = 2048
num_layers: int = 6
num_decoder_layers: int | None = None
num_heads: int = 8
relative_attention_num_buckets: int = 32
relative_attention_max_distance: int = 128
dropout_rate: float = 0.1
layer_norm_epsilon: float = 1e-6
initializer_factor: float = 1.0
feed_forward_proj: str = "relu"
dense_act_fn: str = ""
is_gated_act: bool = False
is_encoder_decoder: bool = True
use_cache: bool = True
pad_token_id: int = 0
eos_token_id: int = 1
classifier_dropout: float = 0.0
text_len: int = 512
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=lambda: [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q", "q"),
(".qkv_proj", ".k", "k"),
(".qkv_proj", ".v", "v"),
]
)
_fsdp_shard_conditions: list = field(
default_factory=lambda: [
_is_transformer_layer,
_is_embeddings,
_is_final_layernorm,
]
)
# Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
def __post_init__(self):
super().__post_init__()
act_info = self.feed_forward_proj.split("-")
self.dense_act_fn: str = act_info[-1]
self.is_gated_act: bool = act_info[0] == "gated"
if self.feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"
self.tokenizer_kwargs = {
"padding": "max_length",
"truncation": True,
"max_length": self.text_len,
"add_special_tokens": True,
"return_attention_mask": True,
"return_tensors": "pt",
}
@dataclass
class T5Config(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)
prefix: str = "t5"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
from sglang.multimodal_gen.configs.models.vaes.stepvideovae import StepVideoVAEConfig
from sglang.multimodal_gen.configs.models.vaes.wanvae import WanVAEConfig
__all__ = [
"HunyuanVAEConfig",
"WanVAEConfig",
"StepVideoVAEConfig",
]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import argparse
import dataclasses
from dataclasses import dataclass, field
from typing import Any
import torch
from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig
from sglang.multimodal_gen.runtime.models.vision_utils import get_default_height_width
from sglang.multimodal_gen.utils import StoreBoolean
@dataclass
class VAEArchConfig(ArchConfig):
scaling_factor: float | torch.Tensor = 0
temporal_compression_ratio: int = 4
# or vae_scale_factor?
spatial_compression_ratio: int = 8
@dataclass
class VAEConfig(ModelConfig):
arch_config: VAEArchConfig = field(default_factory=VAEArchConfig)
# sgl-diffusion VAE-specific parameters
load_encoder: bool = True
load_decoder: bool = True
tile_sample_min_height: int = 256
tile_sample_min_width: int = 256
tile_sample_min_num_frames: int = 16
tile_sample_stride_height: int = 192
tile_sample_stride_width: int = 192
tile_sample_stride_num_frames: int = 12
blend_num_frames: int = 0
use_tiling: bool = True
use_temporal_tiling: bool = True
use_parallel_tiling: bool = True
use_temporal_scaling_frames: bool = True
def __post_init__(self):
self.blend_num_frames = (
self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
)
def post_init(self):
pass
# returns width, height
def calculate_dimensions(
self, image, vae_scale_factor, width, height
) -> tuple[int, int]:
height, width = get_default_height_width(image, vae_scale_factor, height, width)
return width, height
@staticmethod
def add_cli_args(parser: Any, prefix: str = "vae-config") -> Any:
"""Add CLI arguments for VAEConfig fields"""
parser.add_argument(
f"--{prefix}.load-encoder",
action=StoreBoolean,
dest=f"{prefix.replace('-', '_')}.load_encoder",
default=VAEConfig.load_encoder,
help="Whether to load the VAE encoder",
)
parser.add_argument(
f"--{prefix}.load-decoder",
action=StoreBoolean,
dest=f"{prefix.replace('-', '_')}.load_decoder",
default=VAEConfig.load_decoder,
help="Whether to load the VAE decoder",
)
parser.add_argument(
f"--{prefix}.tile-sample-min-height",
type=int,
dest=f"{prefix.replace('-', '_')}.tile_sample_min_height",
default=VAEConfig.tile_sample_min_height,
help="Minimum height for VAE tile sampling",
)
parser.add_argument(
f"--{prefix}.tile-sample-min-width",
type=int,
dest=f"{prefix.replace('-', '_')}.tile_sample_min_width",
default=VAEConfig.tile_sample_min_width,
help="Minimum width for VAE tile sampling",
)
parser.add_argument(
f"--{prefix}.tile-sample-min-num-frames",
type=int,
dest=f"{prefix.replace('-', '_')}.tile_sample_min_num_frames",
default=VAEConfig.tile_sample_min_num_frames,
help="Minimum number of frames for VAE tile sampling",
)
parser.add_argument(
f"--{prefix}.tile-sample-stride-height",
type=int,
dest=f"{prefix.replace('-', '_')}.tile_sample_stride_height",
default=VAEConfig.tile_sample_stride_height,
help="Stride height for VAE tile sampling",
)
parser.add_argument(
f"--{prefix}.tile-sample-stride-width",
type=int,
dest=f"{prefix.replace('-', '_')}.tile_sample_stride_width",
default=VAEConfig.tile_sample_stride_width,
help="Stride width for VAE tile sampling",
)
parser.add_argument(
f"--{prefix}.tile-sample-stride-num-frames",
type=int,
dest=f"{prefix.replace('-', '_')}.tile_sample_stride_num_frames",
default=VAEConfig.tile_sample_stride_num_frames,
help="Stride number of frames for VAE tile sampling",
)
parser.add_argument(
f"--{prefix}.blend-num-frames",
type=int,
dest=f"{prefix.replace('-', '_')}.blend_num_frames",
default=VAEConfig.blend_num_frames,
help="Number of frames to blend for VAE tile sampling",
)
parser.add_argument(
f"--{prefix}.use-tiling",
action=StoreBoolean,
dest=f"{prefix.replace('-', '_')}.use_tiling",
default=VAEConfig.use_tiling,
help="Whether to use tiling for VAE",
)
parser.add_argument(
f"--{prefix}.use-temporal-tiling",
action=StoreBoolean,
dest=f"{prefix.replace('-', '_')}.use_temporal_tiling",
default=VAEConfig.use_temporal_tiling,
help="Whether to use temporal tiling for VAE",
)
parser.add_argument(
f"--{prefix}.use-parallel-tiling",
action=StoreBoolean,
dest=f"{prefix.replace('-', '_')}.use_parallel_tiling",
default=VAEConfig.use_parallel_tiling,
help="Whether to use parallel tiling for VAE",
)
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "VAEConfig":
kwargs = {}
for attr in dataclasses.fields(cls):
value = getattr(args, attr.name, None)
if value is not None:
kwargs[attr.name] = value
return cls(**kwargs)
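add_cli_args registers flags under a dotted prefix (e.g. --vae-config.tile-sample-min-height) and stores them under a dotted dest (vae_config.tile_sample_min_height), which a config updater can later split back into an object name and a field name. A standalone sketch of that naming convention, using plain argparse and store_true in place of the project's FlexibleArgumentParser and StoreBoolean:

import argparse

parser = argparse.ArgumentParser()
prefix = "vae-config"
parser.add_argument(
    f"--{prefix}.tile-sample-min-height",
    type=int,
    dest=f"{prefix.replace('-', '_')}.tile_sample_min_height",
    default=256,
)
parser.add_argument(
    f"--{prefix}.use-tiling",
    action="store_true",  # stand-in for StoreBoolean
    dest=f"{prefix.replace('-', '_')}.use_tiling",
)

args = parser.parse_args(["--vae-config.tile-sample-min-height", "512", "--vae-config.use-tiling"])
# Dotted dests still land on the Namespace; split them to route values to the right sub-config.
for dotted_name, value in vars(args).items():
    config_name, field_name = dotted_name.split(".", 1)
    print(config_name, field_name, value)
# vae_config tile_sample_min_height 512
# vae_config use_tiling True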
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig
@dataclass
class FluxVAEArchConfig(VAEArchConfig):
spatial_compression_ratio: int = 1
base_dim: int = 96
decoder_base_dim: int | None = None
z_dim: int = 16
dim_mult: tuple[int, ...] = (1, 2, 4, 4)
num_res_blocks: int = 2
attn_scales: tuple[float, ...] = ()
temperal_downsample: tuple[bool, ...] = (False, True, True)
dropout: float = 0.0
is_residual: bool = False
in_channels: int = 3
out_channels: int = 3
patch_size: int | None = None
scale_factor_temporal: int = 4
scale_factor_spatial: int = 8
clip_output: bool = True
@dataclass
class FluxVAEConfig(VAEConfig):
arch_config: FluxVAEArchConfig = field(default_factory=FluxVAEArchConfig)
use_feature_cache: bool = True
use_tiling: bool = False
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False
def __post_init__(self):
self.blend_num_frames = (
self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
) * 2
def post_init(self):
self.arch_config.vae_scale_factor = 2 ** (
len(self.arch_config.block_out_channels) - 1
)
self.arch_config.spatial_compression_ratio = self.arch_config.vae_scale_factor
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig
@dataclass
class HunyuanVAEArchConfig(VAEArchConfig):
in_channels: int = 3
out_channels: int = 3
latent_channels: int = 16
down_block_types: tuple[str, ...] = (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
)
up_block_types: tuple[str, ...] = (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
)
block_out_channels: tuple[int, ...] = (128, 256, 512, 512)
layers_per_block: int = 2
act_fn: str = "silu"
norm_num_groups: int = 32
scaling_factor: float = 0.476986
spatial_compression_ratio: int = 8
temporal_compression_ratio: int = 4
mid_block_add_attention: bool = True
def __post_init__(self):
self.spatial_compression_ratio: int = 2 ** (len(self.block_out_channels) - 1)
@dataclass
class HunyuanVAEConfig(VAEConfig):
arch_config: VAEArchConfig = field(default_factory=HunyuanVAEArchConfig)
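With the default four block_out_channels stages, __post_init__ recomputes spatial_compression_ratio as 2 ** (len(block_out_channels) - 1) = 8, consistent with the declared default. A one-line standalone check:

block_out_channels = (128, 256, 512, 512)
# Every down block after the first halves the spatial resolution once.
print(2 ** (len(block_out_channels) - 1))  # 8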
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig
@dataclass
class QwenImageVAEArchConfig(VAEArchConfig):
spatial_compression_ratio: int = 1
base_dim: int = 96
decoder_base_dim: int | None = None
z_dim: int = 16
dim_mult: tuple[int, ...] = (1, 2, 4, 4)
num_res_blocks: int = 2
attn_scales: tuple[float, ...] = ()
temperal_downsample: tuple[bool, ...] = (False, True, True)
dropout: float = 0.0
is_residual: bool = False
in_channels: int = 3
out_channels: int = 3
patch_size: int | None = None
scale_factor_temporal: int = 4
scale_factor_spatial: int = 8
clip_output: bool = True
def __post_init__(self):
self.vae_scale_factor = 2 ** len(self.temperal_downsample)
@dataclass
class QwenImageVAEConfig(VAEConfig):
arch_config: QwenImageVAEArchConfig = field(default_factory=QwenImageVAEArchConfig)
use_feature_cache: bool = True
use_tiling: bool = False
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False
def calculate_dimensions(self, image, vae_scale_factor, width, height):
width = image.size[0]
height = image.size[1]
width, height, _ = calculate_dimensions(1024 * 1024, width / height)
return width, height
def __post_init__(self):
self.blend_num_frames = (
self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
) * 2
def post_init(self):
self.arch_config.vae_scale_factor = 2 ** (
len(self.arch_config.temperal_downsample)
)
self.arch_config.spatial_compression_ratio = self.arch_config.vae_scale_factor
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig
@dataclass
class StepVideoVAEArchConfig(VAEArchConfig):
in_channels: int = 3
out_channels: int = 3
z_channels: int = 64
num_res_blocks: int = 2
version: int = 2
frame_len: int = 17
world_size: int = 1
spatial_compression_ratio: int = 16
temporal_compression_ratio: int = 8
scaling_factor: float = 1.0
@dataclass
class StepVideoVAEConfig(VAEConfig):
arch_config: VAEArchConfig = field(default_factory=StepVideoVAEArchConfig)
use_tiling: bool = False
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False
use_temporal_scaling_frames: bool = False
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
import torch
from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig
@dataclass
class WanVAEArchConfig(VAEArchConfig):
base_dim: int = 96
decoder_base_dim: int | None = None
z_dim: int = 16
dim_mult: tuple[int, ...] = (1, 2, 4, 4)
num_res_blocks: int = 2
attn_scales: tuple[float, ...] = ()
temperal_downsample: tuple[bool, ...] = (False, True, True)
dropout: float = 0.0
latents_mean: tuple[float, ...] = (
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
)
latents_std: tuple[float, ...] = (
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
)
is_residual: bool = False
in_channels: int = 3
out_channels: int = 3
patch_size: int | None = None
scale_factor_temporal: int = 4
scale_factor_spatial: int = 8
clip_output: bool = True
def __post_init__(self):
self.scaling_factor: torch.Tensor = 1.0 / torch.tensor(self.latents_std).view(
1, self.z_dim, 1, 1, 1
)
self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view(
1, self.z_dim, 1, 1, 1
)
self.temporal_compression_ratio = self.scale_factor_temporal
self.spatial_compression_ratio = self.scale_factor_spatial
@dataclass
class WanVAEConfig(VAEConfig):
arch_config: WanVAEArchConfig = field(default_factory=WanVAEArchConfig)
use_feature_cache: bool = True
use_tiling: bool = False
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False
def __post_init__(self):
self.blend_num_frames = (
self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
) * 2
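WanVAEArchConfig turns latents_mean/latents_std into per-channel shift and scale tensors shaped (1, z_dim, 1, 1, 1). A hedged sketch of the usual direction of use, normalizing latents as (latents - shift_factor) * scaling_factor before the DiT and inverting it before decoding; where exactly the pipeline applies this is not shown in this file, and the tensors below are stand-ins for the tuples above.

import torch

z_dim = 16
latents_mean = torch.randn(z_dim)        # stand-in values
latents_std = torch.rand(z_dim) + 0.5    # stand-in values

scaling_factor = 1.0 / latents_std.view(1, z_dim, 1, 1, 1)
shift_factor = latents_mean.view(1, z_dim, 1, 1, 1)

latents = torch.randn(1, z_dim, 4, 8, 8)                 # (batch, channels, frames, h, w)
normalized = (latents - shift_factor) * scaling_factor   # before the DiT
recovered = normalized / scaling_factor + shift_factor   # before VAE decode
print(torch.allclose(recovered, latents, atol=1e-5))     # True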
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.configs.pipelines.base import (
PipelineConfig,
SlidingTileAttnConfig,
)
from sglang.multimodal_gen.configs.pipelines.flux import FluxPipelineConfig
from sglang.multimodal_gen.configs.pipelines.hunyuan import (
FastHunyuanConfig,
HunyuanConfig,
)
from sglang.multimodal_gen.configs.pipelines.registry import (
get_pipeline_config_cls_from_name,
)
from sglang.multimodal_gen.configs.pipelines.stepvideo import StepVideoT2VConfig
from sglang.multimodal_gen.configs.pipelines.wan import (
SelfForcingWanT2V480PConfig,
WanI2V480PConfig,
WanI2V720PConfig,
WanT2V480PConfig,
WanT2V720PConfig,
)
__all__ = [
"HunyuanConfig",
"FastHunyuanConfig",
"FluxPipelineConfig",
"PipelineConfig",
"SlidingTileAttnConfig",
"WanT2V480PConfig",
"WanI2V480PConfig",
"WanT2V720PConfig",
"WanI2V720PConfig",
"StepVideoT2VConfig",
"SelfForcingWanT2V480PConfig",
"get_pipeline_config_cls_from_name",
]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import json
from collections.abc import Callable
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from typing import Any, cast
import torch
from diffusers.image_processor import VaeImageProcessor
from sglang.multimodal_gen.configs.models import (
DiTConfig,
EncoderConfig,
ModelConfig,
VAEConfig,
)
from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput
from sglang.multimodal_gen.configs.utils import update_config_from_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import (
FlexibleArgumentParser,
StoreBoolean,
shallow_asdict,
)
logger = init_logger(__name__)
class STA_Mode(str, Enum):
"""STA (Sliding Tile Attention) modes."""
STA_INFERENCE = "STA_inference"
STA_SEARCHING = "STA_searching"
STA_TUNING = "STA_tuning"
STA_TUNING_CFG = "STA_tuning_cfg"
NONE = None
def preprocess_text(prompt: str) -> str:
return prompt
def postprocess_text(output: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
raise NotImplementedError
# config for a single pipeline
@dataclass
class PipelineConfig:
"""Base configuration for all pipeline architectures."""
model_path: str = ""
pipeline_config_path: str | None = None
is_image_gen: bool = False
# generation parameters
# controls the timestep embedding generation
should_use_guidance: bool = True
embedded_cfg_scale: float = 6.0
flow_shift: float | None = None
disable_autocast: bool = False
# Model configuration
dit_config: DiTConfig = field(default_factory=DiTConfig)
dit_precision: str = "bf16"
# VAE configuration
vae_config: VAEConfig = field(default_factory=VAEConfig)
vae_precision: str = "fp32"
vae_tiling: bool = True
vae_sp: bool = True
# Image encoder configuration
image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig)
image_encoder_precision: str = "fp32"
# Text encoder configuration
DEFAULT_TEXT_ENCODER_PRECISIONS = ("fp32",)
text_encoder_configs: tuple[EncoderConfig, ...] = field(
default_factory=lambda: (EncoderConfig(),)
)
# See PRECISION_TO_TYPE for detailed mapping
text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",))
text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}])
# image encoding
image_encoder_extra_args: dict = field(default_factory=lambda: {})
def postprocess_image(self, image):
return image.last_hidden_state
preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (preprocess_text,)
)
postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = (
field(default_factory=lambda: (postprocess_text,))
)
# StepVideo specific parameters
pos_magic: str | None = None
neg_magic: str | None = None
timesteps_scale: bool | None = None
# STA (Sliding Tile Attention) parameters
mask_strategy_file_path: str | None = None
STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
skip_time_steps: int = 15
# DMD parameters
dmd_denoising_steps: list[int] | None = field(default=None)
# Wan2.2 TI2V parameters
ti2v_task: bool = False
i2v_task: bool = False
ti2i_task: bool = False
boundary_ratio: float | None = None
# Compilation
# enable_torch_compile: bool = False
def slice_noise_pred(self, noise, latents):
return noise
def set_width_and_height(self, width, height, image):
"""
image: input image
"""
return width, height
# called in ImageEncodingStage, preprocess the image
def preprocess_image(self, image, image_processor: VaeImageProcessor):
return image
def prepare_latent_shape(self, batch, batch_size, num_frames):
height = batch.height // self.vae_config.arch_config.spatial_compression_ratio
width = batch.width // self.vae_config.arch_config.spatial_compression_ratio
# Calculate latent shape
shape = (
batch_size,
self.dit_config.num_channels_latents,
num_frames,
height,
width,
)
return shape
# called after latents are prepared
def pack_latents(self, latents, batch_size, batch):
return latents
def get_pos_prompt_embeds(self, batch):
return batch.prompt_embeds
def get_neg_prompt_embeds(self, batch):
return batch.negative_prompt_embeds
def post_denoising_loop(self, latents, batch):
return latents
def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):
return {}
def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
return {}
@staticmethod
def add_cli_args(
parser: FlexibleArgumentParser, prefix: str = ""
) -> FlexibleArgumentParser:
prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else ""
# model_path would conflict with the model_path in ServerArgs,
# so we only add it here when a prefix is given
if prefix_with_dot != "":
parser.add_argument(
f"--{prefix_with_dot}model-path",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}model_path",
default=PipelineConfig.model_path,
help="Path to the pretrained model",
)
parser.add_argument(
f"--{prefix_with_dot}pipeline-config-path",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}pipeline_config_path",
default=PipelineConfig.pipeline_config_path,
help="Path to the pipeline config",
)
parser.add_argument(
f"--{prefix_with_dot}embedded-cfg-scale",
type=float,
dest=f"{prefix_with_dot.replace('-', '_')}embedded_cfg_scale",
default=PipelineConfig.embedded_cfg_scale,
help="Embedded CFG scale",
)
parser.add_argument(
f"--{prefix_with_dot}flow-shift",
type=float,
dest=f"{prefix_with_dot.replace('-', '_')}flow_shift",
default=PipelineConfig.flow_shift,
help="Flow shift parameter",
)
# DiT configuration
parser.add_argument(
f"--{prefix_with_dot}dit-precision",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}dit_precision",
default=PipelineConfig.dit_precision,
choices=["fp32", "fp16", "bf16"],
help="Precision for the DiT model",
)
# VAE configuration
parser.add_argument(
f"--{prefix_with_dot}vae-precision",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}vae_precision",
default=PipelineConfig.vae_precision,
choices=["fp32", "fp16", "bf16"],
help="Precision for VAE",
)
parser.add_argument(
f"--{prefix_with_dot}vae-tiling",
action=StoreBoolean,
dest=f"{prefix_with_dot.replace('-', '_')}vae_tiling",
default=PipelineConfig.vae_tiling,
help="Enable VAE tiling",
)
parser.add_argument(
f"--{prefix_with_dot}vae-sp",
action=StoreBoolean,
dest=f"{prefix_with_dot.replace('-', '_')}vae_sp",
help="Enable VAE spatial parallelism",
)
# Text encoder configuration
parser.add_argument(
f"--{prefix_with_dot}text-encoder-precisions",
nargs="+",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}text_encoder_precisions",
default=PipelineConfig.DEFAULT_TEXT_ENCODER_PRECISIONS,
choices=["fp32", "fp16", "bf16"],
help="Precision for each text encoder",
)
# Image encoder configuration
parser.add_argument(
f"--{prefix_with_dot}image-encoder-precision",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}image_encoder_precision",
default=PipelineConfig.image_encoder_precision,
choices=["fp32", "fp16", "bf16"],
help="Precision for image encoder",
)
parser.add_argument(
f"--{prefix_with_dot}pos_magic",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}pos_magic",
default=PipelineConfig.pos_magic,
help="Positive magic prompt for sampling, used in stepvideo",
)
parser.add_argument(
f"--{prefix_with_dot}neg_magic",
type=str,
dest=f"{prefix_with_dot.replace('-', '_')}neg_magic",
default=PipelineConfig.neg_magic,
help="Negative magic prompt for sampling, used in stepvideo",
)
parser.add_argument(
f"--{prefix_with_dot}timesteps_scale",
type=bool,
dest=f"{prefix_with_dot.replace('-', '_')}timesteps_scale",
default=PipelineConfig.timesteps_scale,
help="Bool for applying scheduler scale in set_timesteps, used in stepvideo",
)
# DMD parameters
parser.add_argument(
f"--{prefix_with_dot}dmd-denoising-steps",
type=parse_int_list,
default=PipelineConfig.dmd_denoising_steps,
help="Comma-separated list of denoising steps (e.g., '1000,757,522')",
)
# Add VAE configuration arguments
from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig
VAEConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}vae-config")
# Add DiT configuration arguments
from sglang.multimodal_gen.configs.models.dits.base import DiTConfig
DiTConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}dit-config")
return parser
def update_config_from_dict(self, args: dict[str, Any], prefix: str = "") -> None:
prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else ""
update_config_from_args(self, args, prefix, pop_args=True)
update_config_from_args(
self.vae_config, args, f"{prefix_with_dot}vae_config", pop_args=True
)
update_config_from_args(
self.dit_config, args, f"{prefix_with_dot}dit_config", pop_args=True
)
@classmethod
def from_pretrained(cls, model_path: str) -> "PipelineConfig":
"""
use the pipeline class setting from model_path to match the pipeline config
"""
from sglang.multimodal_gen.configs.pipelines.registry import (
get_pipeline_config_cls_from_name,
)
pipeline_config_cls = get_pipeline_config_cls_from_name(model_path)
return cast(PipelineConfig, pipeline_config_cls(model_path=model_path))
@classmethod
def from_kwargs(
cls, kwargs: dict[str, Any], config_cli_prefix: str = ""
) -> "PipelineConfig":
"""
Load PipelineConfig from kwargs Dictionary.
kwargs: dictionary of kwargs
config_cli_prefix: prefix of CLI arguments for this PipelineConfig instance
"""
from sglang.multimodal_gen.configs.pipelines.registry import (
get_pipeline_config_cls_from_name,
)
prefix_with_dot = (
f"{config_cli_prefix}." if (config_cli_prefix.strip() != "") else ""
)
model_path: str | None = kwargs.get(
prefix_with_dot + "model_path", None
) or kwargs.get("model_path")
pipeline_config_or_path: str | PipelineConfig | dict[str, Any] | None = (
kwargs.get(prefix_with_dot + "pipeline_config", None)
or kwargs.get("pipeline_config")
)
if model_path is None:
raise ValueError("model_path is required in kwargs")
# 1. Get the pipeline config class from the registry
pipeline_config_cls = get_pipeline_config_cls_from_name(model_path)
# 2. Instantiate PipelineConfig
if pipeline_config_cls is None:
logger.warning(
"Couldn't find pipeline config for %s. Using the default pipeline config.",
model_path,
)
pipeline_config = cls()
else:
pipeline_config = pipeline_config_cls()
# 3. Load PipelineConfig from a json file or a PipelineConfig object if provided
if isinstance(pipeline_config_or_path, str):
pipeline_config.load_from_json(pipeline_config_or_path)
kwargs[prefix_with_dot + "pipeline_config_path"] = pipeline_config_or_path
elif isinstance(pipeline_config_or_path, PipelineConfig):
pipeline_config = pipeline_config_or_path
elif isinstance(pipeline_config_or_path, dict):
pipeline_config.update_pipeline_config(pipeline_config_or_path)
# 4. Update PipelineConfig from CLI arguments if provided
kwargs[prefix_with_dot + "model_path"] = model_path
pipeline_config.update_config_from_dict(kwargs, config_cli_prefix)
return pipeline_config
def check_pipeline_config(self) -> None:
if self.vae_sp and not self.vae_tiling:
raise ValueError(
"Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True."
)
if len(self.text_encoder_configs) != len(self.text_encoder_precisions):
raise ValueError(
f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text encoder precisions ({len(self.text_encoder_precisions)})"
)
if len(self.text_encoder_configs) != len(self.preprocess_text_funcs):
raise ValueError(
f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})"
)
if len(self.preprocess_text_funcs) != len(self.postprocess_text_funcs):
raise ValueError(
f"Length of text postprocess functions ({len(self.postprocess_text_funcs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})"
)
def dump_to_json(self, file_path: str):
output_dict = shallow_asdict(self)
del_keys = []
for key, value in output_dict.items():
if isinstance(value, ModelConfig):
model_dict = asdict(value)
# Model Arch Config should be hidden away from the users
model_dict.pop("arch_config")
output_dict[key] = model_dict
elif isinstance(value, tuple) and all(
isinstance(v, ModelConfig) for v in value
):
model_dicts = []
for v in value:
model_dict = asdict(v)
# Model Arch Config should be hidden away from the users
model_dict.pop("arch_config")
model_dicts.append(model_dict)
output_dict[key] = model_dicts
elif isinstance(value, tuple) and all(callable(f) for f in value):
# Skip dumping functions
del_keys.append(key)
for key in del_keys:
output_dict.pop(key, None)
with open(file_path, "w") as f:
json.dump(output_dict, f, indent=2)
def load_from_json(self, file_path: str):
with open(file_path) as f:
input_pipeline_dict = json.load(f)
self.update_pipeline_config(input_pipeline_dict)
def update_pipeline_config(self, source_pipeline_dict: dict[str, Any]) -> None:
for f in fields(self):
key = f.name
if key in source_pipeline_dict:
current_value = getattr(self, key)
new_value = source_pipeline_dict[key]
# If it's a nested ModelConfig, update it recursively
if isinstance(current_value, ModelConfig):
current_value.update_model_config(new_value)
elif isinstance(current_value, tuple) and all(
isinstance(v, ModelConfig) for v in current_value
):
assert len(current_value) == len(
new_value
), "Users shouldn't delete or add text encoder config objects in your json"
for target_config, source_config in zip(
current_value, new_value, strict=True
):
target_config.update_model_config(source_config)
else:
setattr(self, key, new_value)
if hasattr(self, "__post_init__"):
self.__post_init__()
@dataclass
class SlidingTileAttnConfig(PipelineConfig):
"""Configuration for sliding tile attention."""
# Override any BaseConfig defaults as needed
# Add sliding tile specific parameters
window_size: int = 16
stride: int = 8
# You can provide custom defaults for inherited fields
height: int = 576
width: int = 1024
# Additional configuration specific to sliding tile attention
pad_to_square: bool = False
use_overlap_optimization: bool = True
def parse_int_list(value: str) -> list[int]:
"""Parse a comma-separated string of integers into a list."""
if not value:
return []
return [int(x.strip()) for x in value.split(",")]
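parse_int_list backs the --dmd-denoising-steps flag, turning a comma-separated string into a list of ints. A quick standalone usage check (the helper is re-declared here so the snippet runs on its own):

import argparse

def parse_int_list(value: str) -> list[int]:
    # Same logic as the helper above.
    if not value:
        return []
    return [int(x.strip()) for x in value.split(",")]

parser = argparse.ArgumentParser()
parser.add_argument("--dmd-denoising-steps", type=parse_int_list, default=None)
args = parser.parse_args(["--dmd-denoising-steps", "1000, 757, 522"])
print(args.dmd_denoising_steps)  # [1000, 757, 522]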
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from dataclasses import dataclass, field
from typing import Callable
import torch
from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig
from sglang.multimodal_gen.configs.models.encoders import (
BaseEncoderOutput,
CLIPTextConfig,
T5Config,
)
from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig, preprocess_text
from sglang.multimodal_gen.configs.pipelines.hunyuan import (
clip_postprocess_text,
clip_preprocess_text,
)
def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
return outputs.last_hidden_state
@dataclass
class FluxPipelineConfig(PipelineConfig):
# FIXME: duplicate with SamplingParams.guidance_scale?
embedded_cfg_scale: float = 3.5
is_image_gen: bool = True
vae_tiling: bool = False
vae_sp: bool = False
dit_config: DiTConfig = field(default_factory=FluxConfig)
# VAE
vae_config: VAEConfig = field(default_factory=FluxVAEConfig)
# Text encoding stage
text_encoder_configs: tuple[EncoderConfig, ...] = field(
default_factory=lambda: (CLIPTextConfig(), T5Config())
)
text_encoder_precisions: tuple[str, ...] = field(
default_factory=lambda: ("bf16", "bf16")
)
preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (clip_preprocess_text, preprocess_text),
)
postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = field(
default_factory=lambda: (clip_postprocess_text, t5_postprocess_text)
)
text_encoder_extra_args: list[dict] = field(
default_factory=lambda: [
dict(
max_length=77,
padding="max_length",
truncation=True,
return_overflowing_tokens=False,
return_length=False,
),
None,
]
)
def prepare_latent_shape(self, batch, batch_size, num_frames):
height = 2 * (
batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)
)
width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))
num_channels_latents = self.dit_config.arch_config.in_channels // 4
shape = (batch_size, num_channels_latents, height, width)
return shape
def pack_latents(self, latents, batch_size, batch):
height = 2 * (
batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)
)
width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))
num_channels_latents = self.dit_config.arch_config.in_channels // 4
# pack latents
latents = latents.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
return latents
def get_pos_prompt_embeds(self, batch):
return batch.prompt_embeds[1]
def get_neg_prompt_embeds(self, batch):
return batch.negative_prompt_embeds[1]
def _prepare_latent_image_ids(self, original_height, original_width, device):
vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
height = int(original_height) // (vae_scale_factor * 2)
width = int(original_width) // (vae_scale_factor * 2)
latent_image_ids = torch.zeros(height, width, 3, device=device)
latent_image_ids[..., 1] = (
latent_image_ids[..., 1] + torch.arange(height, device=device)[:, None]
)
latent_image_ids[..., 2] = (
latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :]
)
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
latent_image_ids.shape
)
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids
def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb):
txt_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device)
img_ids = self._prepare_latent_image_ids(
original_height=height,
original_width=width,
device=device,
)
ids = torch.cat([txt_ids, img_ids], dim=0).to(device=device)
# NOTE(mick): prepare it here, to avoid unnecessary computations
freqs_cis = rotary_emb.forward(ids)
return freqs_cis
def post_denoising_loop(self, latents, batch):
# unpack latents for flux
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
batch_size = latents.shape[0]
channels = latents.shape[-1]
vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
height = 2 * (int(batch.height) // (vae_scale_factor * 2))
width = 2 * (int(batch.width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
return latents
def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):
return {
"freqs_cis": self.get_freqs_cis(
batch.prompt_embeds[1], batch.width, batch.height, device, rotary_emb
),
"pooled_projections": (
batch.pooled_embeds[0] if batch.pooled_embeds else None
),
}
def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
return {
"freqs_cis": self.get_freqs_cis(
batch.negative_prompt_embeds[1],
batch.width,
batch.height,
device,
rotary_emb,
),
"pooled_projections": (
batch.neg_pooled_embeds[0] if batch.neg_pooled_embeds else None
),
}
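pack_latents folds each 2x2 spatial patch into the channel dimension, so the DiT sees a sequence of (height/2) * (width/2) tokens of size 4*C, and post_denoising_loop inverts that layout before VAE decoding. A standalone shape check of the round trip, using a stand-in vae_scale_factor of 8; the tensor sizes are illustrative, not taken from this commit.

import torch

batch_size, channels = 1, 16
vae_scale_factor = 8                                 # stand-in for the Flux VAE's value
image_h, image_w = 512, 512
height = 2 * (image_h // (vae_scale_factor * 2))     # 64
width = 2 * (image_w // (vae_scale_factor * 2))      # 64

latents = torch.randn(batch_size, channels, height, width)

# Pack: (B, C, H, W) -> (B, H/2 * W/2, C*4), mirroring FluxPipelineConfig.pack_latents.
packed = latents.view(batch_size, channels, height // 2, 2, width // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(batch_size, (height // 2) * (width // 2), channels * 4)

# Unpack: mirroring FluxPipelineConfig.post_denoising_loop.
unpacked = packed.view(batch_size, height // 2, width // 2, channels, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5).reshape(batch_size, channels, height, width)

print(packed.shape, torch.equal(unpacked, latents))
# torch.Size([1, 1024, 64]) True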