Unverified Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TypedDict
import torch
from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig
from sglang.multimodal_gen.configs.models.encoders import (
BaseEncoderOutput,
CLIPTextConfig,
LlamaConfig,
)
from sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
class PromptTemplate(TypedDict):
template: str
crop_start: int
prompt_template_video: PromptTemplate = {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
}
def llama_preprocess_text(prompt: str) -> str:
return prompt_template_video["template"].format(prompt)
def llama_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
hidden_state_skip_layer = 2
assert outputs.hidden_states is not None
hidden_states: tuple[torch.Tensor, ...] = outputs.hidden_states
    last_hidden_state: torch.Tensor = hidden_states[-(hidden_state_skip_layer + 1)]
crop_start = prompt_template_video.get("crop_start", -1)
last_hidden_state = last_hidden_state[:, crop_start:]
return last_hidden_state
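# Note: hidden_state_skip_layer=2 selects the hidden states three entries from the
# end of outputs.hidden_states (i.e., skipping the final two layers), and
# crop_start=95 drops the tokens contributed by the system-prompt portion of
# PROMPT_TEMPLATE_ENCODE_VIDEO, so only the user prompt's embeddings are returned.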
def clip_preprocess_text(prompt: str) -> str:
return prompt
def clip_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
    pooler_output: torch.Tensor = outputs.pooler_output
return pooler_output
@dataclass
class HunyuanConfig(PipelineConfig):
"""Base configuration for HunYuan pipeline architecture."""
# HunyuanConfig-specific parameters with defaults
# DiT
dit_config: DiTConfig = field(default_factory=HunyuanVideoConfig)
# VAE
vae_config: VAEConfig = field(default_factory=HunyuanVAEConfig)
# Denoising stage
embedded_cfg_scale: int = 6
flow_shift: int = 7
# Text encoding stage
text_encoder_configs: tuple[EncoderConfig, ...] = field(
default_factory=lambda: (LlamaConfig(), CLIPTextConfig())
)
preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (llama_preprocess_text, clip_preprocess_text)
)
    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = (
field(default_factory=lambda: (llama_postprocess_text, clip_postprocess_text))
)
# Precision for each component
dit_precision: str = "bf16"
vae_precision: str = "fp16"
text_encoder_precisions: tuple[str, ...] = field(
default_factory=lambda: ("fp16", "fp16")
)
def __post_init__(self):
self.vae_config.load_encoder = False
self.vae_config.load_decoder = True
@dataclass
class FastHunyuanConfig(HunyuanConfig):
"""Configuration specifically optimized for FastHunyuan weights."""
# Override HunyuanConfig defaults
flow_shift: int = 17
# No need to re-specify guidance_scale or embedded_cfg_scale as they
# already have the desired values from HunyuanConfig
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from dataclasses import dataclass, field
from collections.abc import Callable
import torch
from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig
from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig
from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
def qwen_image_preprocess_text(prompt):
prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
template = prompt_template_encode
txt = template.format(prompt)
return txt
def qwen_image_postprocess_text(outputs, _text_inputs, drop_idx=34):
    # take the last layer's hidden states
hidden_states = outputs.hidden_states[-1]
split_hidden_states = _extract_masked_hidden(
hidden_states, _text_inputs.attention_mask
)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[
torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))])
for u in split_hidden_states
]
)
return prompt_embeds
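# Note: drop_idx=34 corresponds to the tokens added by the chat-template prefix in
# qwen_image_preprocess_text; the remaining per-sample embeddings are zero-padded to
# the longest sequence in the batch before being stacked.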
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
return latents
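# Shape-wise, _pack_latents maps (B, C, H, W) latents to
# (B, (H // 2) * (W // 2), C * 4) by folding each 2x2 spatial patch into the channel
# dimension, e.g. (1, 16, 64, 64) -> (1, 1024, 64).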
@dataclass
class QwenImagePipelineConfig(PipelineConfig):
should_use_guidance: bool = False
is_image_gen: bool = True
vae_tiling: bool = False
vae_sp: bool = False
dit_config: DiTConfig = field(default_factory=QwenImageDitConfig)
# VAE
vae_config: VAEConfig = field(default_factory=QwenImageVAEConfig)
# Text encoding stage
text_encoder_configs: tuple[EncoderConfig, ...] = field(
default_factory=lambda: (Qwen2_5VLConfig(),)
)
text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("bf16",))
preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (qwen_image_preprocess_text,)
)
    postprocess_text_funcs: tuple[Callable[..., torch.Tensor], ...] = field(
default_factory=lambda: (qwen_image_postprocess_text,)
)
text_encoder_extra_args: list[dict] = field(
default_factory=lambda: [
dict(
padding=True,
truncation=True,
),
None,
]
)
def get_vae_scale_factor(self):
return self.vae_config.arch_config.vae_scale_factor
def prepare_latent_shape(self, batch, batch_size, num_frames):
height = 2 * (
batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)
)
width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))
num_channels_latents = self.dit_config.arch_config.in_channels // 4
shape = (batch_size, num_channels_latents, height, width)
return shape
def pack_latents(self, latents, batch_size, batch):
height = 2 * (
batch.height // (self.vae_config.arch_config.vae_scale_factor * 2)
)
width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2))
num_channels_latents = self.dit_config.arch_config.in_channels // 4
# pack latents
# _pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = latents.view(
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
)
return latents
@staticmethod
def get_freqs_cis(img_shapes, txt_seq_lens, rotary_emb, device, dtype):
img_freqs, txt_freqs = rotary_emb(img_shapes, txt_seq_lens, device=device)
img_cos, img_sin = (
img_freqs.real.to(dtype=dtype),
img_freqs.imag.to(dtype=dtype),
)
txt_cos, txt_sin = (
txt_freqs.real.to(dtype=dtype),
txt_freqs.imag.to(dtype=dtype),
)
return (img_cos, img_sin), (txt_cos, txt_sin)
def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):
batch_size = batch.latents.shape[0]
vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
img_shapes = [
[
(
1,
batch.height // vae_scale_factor // 2,
batch.width // vae_scale_factor // 2,
)
]
] * batch_size
txt_seq_lens = [batch.prompt_embeds[0].shape[1]]
return {
"img_shapes": img_shapes,
"txt_seq_lens": txt_seq_lens,
"freqs_cis": QwenImagePipelineConfig.get_freqs_cis(
img_shapes, txt_seq_lens, rotary_emb, device, dtype
),
}
def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
batch_size = batch.latents.shape[0]
vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
img_shapes = [
[
(
1,
batch.height // vae_scale_factor // 2,
batch.width // vae_scale_factor // 2,
)
]
] * batch_size
txt_seq_lens = [batch.negative_prompt_embeds[0].shape[1]]
return {
"img_shapes": img_shapes,
"txt_seq_lens": txt_seq_lens,
"freqs_cis": QwenImagePipelineConfig.get_freqs_cis(
img_shapes, txt_seq_lens, rotary_emb, device, dtype
),
}
def post_denoising_loop(self, latents, batch):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
batch_size = latents.shape[0]
channels = latents.shape[-1]
vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
height = 2 * (int(batch.height) // (vae_scale_factor * 2))
width = 2 * (int(batch.width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
class QwenImageEditPipelineConfig(QwenImagePipelineConfig):
ti2i_task = True
def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype):
# TODO: lots of duplications here
batch_size = batch.latents.shape[0]
height = batch.height
width = batch.width
image = batch.pil_image
image_size = image[0].size if isinstance(image, list) else image.size
calculated_width, calculated_height, _ = calculate_dimensions(
1024 * 1024, image_size[0] / image_size[1]
)
vae_scale_factor = self.get_vae_scale_factor()
img_shapes = [
[
(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
(
1,
calculated_height // vae_scale_factor // 2,
calculated_width // vae_scale_factor // 2,
),
]
] * batch_size
txt_seq_lens = [batch.prompt_embeds[0].shape[1]]
return {
"img_shapes": img_shapes,
"txt_seq_lens": txt_seq_lens,
"freqs_cis": QwenImagePipelineConfig.get_freqs_cis(
img_shapes, txt_seq_lens, rotary_emb, device, dtype
),
}
def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype):
batch_size = batch.latents.shape[0]
height = batch.height
width = batch.width
image = batch.pil_image
image_size = image[0].size if isinstance(image, list) else image.size
calculated_width, calculated_height, _ = calculate_dimensions(
1024 * 1024, image_size[0] / image_size[1]
)
vae_scale_factor = self.get_vae_scale_factor()
img_shapes = [
[
(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
(
1,
calculated_height // vae_scale_factor // 2,
calculated_width // vae_scale_factor // 2,
),
]
] * batch_size
txt_seq_lens = [batch.negative_prompt_embeds[0].shape[1]]
return {
"img_shapes": img_shapes,
"txt_seq_lens": txt_seq_lens,
"freqs_cis": QwenImagePipelineConfig.get_freqs_cis(
img_shapes, txt_seq_lens, rotary_emb, device, dtype
),
}
def prepare_latent_shape(self, batch, batch_size, num_frames):
vae_scale_factor = self.vae_config.arch_config.vae_scale_factor
height = 2 * (batch.height // (vae_scale_factor * 2))
width = 2 * (batch.width // (vae_scale_factor * 2))
num_channels_latents = self.dit_config.arch_config.in_channels // 4
shape = (batch_size, 1, num_channels_latents, height, width)
return shape
def preprocess_image(self, image, image_processor):
image_size = image[0].size if isinstance(image, list) else image.size
calculated_width, calculated_height, _ = calculate_dimensions(
1024 * 1024, image_size[0] / image_size[1]
)
image = image_processor.resize(image, calculated_height, calculated_width)
return image
def set_width_and_height(self, width, height, image):
image_size = image[0].size if isinstance(image, list) else image.size
calculated_width, calculated_height, _ = calculate_dimensions(
1024 * 1024, image_size[0] / image_size[1]
)
height = height or calculated_height
width = width or calculated_width
multiple_of = self.get_vae_scale_factor() * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
return width, height
def slice_noise_pred(self, noise, latents):
noise = noise[:, : latents.size(1)]
return noise
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""Registry for pipeline weight-specific configurations."""
import os
from collections.abc import Callable
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
from sglang.multimodal_gen.configs.pipelines.flux import FluxPipelineConfig
from sglang.multimodal_gen.configs.pipelines.hunyuan import (
FastHunyuanConfig,
HunyuanConfig,
)
from sglang.multimodal_gen.configs.pipelines.qwen_image import (
QwenImageEditPipelineConfig,
QwenImagePipelineConfig,
)
from sglang.multimodal_gen.configs.pipelines.stepvideo import StepVideoT2VConfig
# isort: off
from sglang.multimodal_gen.configs.pipelines.wan import (
FastWan2_1_T2V_480P_Config,
FastWan2_2_TI2V_5B_Config,
Wan2_2_I2V_A14B_Config,
Wan2_2_T2V_A14B_Config,
Wan2_2_TI2V_5B_Config,
WanI2V480PConfig,
WanI2V720PConfig,
WanT2V480PConfig,
WanT2V720PConfig,
SelfForcingWanT2V480PConfig,
)
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
verify_model_config_and_directory,
maybe_download_model_index,
)
# isort: on
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
# Registry maps specific model weights to their config classes
PIPE_NAME_TO_CONFIG: dict[str, type[PipelineConfig]] = {
"FastVideo/FastHunyuan-diffusers": FastHunyuanConfig,
"hunyuanvideo-community/HunyuanVideo": HunyuanConfig,
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V480PConfig,
"weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers": WanI2V480PConfig,
"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V480PConfig,
"Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V720PConfig,
"Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V720PConfig,
"FastVideo/FastWan2.1-T2V-1.3B-Diffusers": FastWan2_1_T2V_480P_Config,
"FastVideo/FastWan2.1-T2V-14B-480P-Diffusers": FastWan2_1_T2V_480P_Config,
"FastVideo/FastWan2.2-TI2V-5B-Diffusers": FastWan2_2_TI2V_5B_Config,
"FastVideo/stepvideo-t2v-diffusers": StepVideoT2VConfig,
"FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers": WanT2V720PConfig,
"wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers": SelfForcingWanT2V480PConfig,
"Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_Config,
"Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_Config,
"Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_Config,
# Add other specific weight variants
"black-forest-labs/FLUX.1-dev": FluxPipelineConfig,
"Qwen/Qwen-Image": QwenImagePipelineConfig,
"Qwen/Qwen-Image-Edit": QwenImageEditPipelineConfig,
}
# For determining pipeline type from model ID
PIPELINE_DETECTOR: dict[str, Callable[[str], bool]] = {
"hunyuan": lambda id: "hunyuan" in id.lower(),
"wanpipeline": lambda id: "wanpipeline" in id.lower(),
"wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(),
"wandmdpipeline": lambda id: "wandmdpipeline" in id.lower(),
"wancausaldmdpipeline": lambda id: "wancausaldmdpipeline" in id.lower(),
"stepvideo": lambda id: "stepvideo" in id.lower(),
"qwenimage": lambda id: "qwen-image" in id.lower() and "edit" not in id.lower(),
"qwenimageedit": lambda id: "qwen-image-edit" in id.lower(),
# Add other pipeline architecture detectors
}
# Fallback configs when exact match isn't found but architecture is detected
PIPELINE_FALLBACK_CONFIG: dict[str, type[PipelineConfig]] = {
"hunyuan": HunyuanConfig, # Base Hunyuan config as fallback for any Hunyuan variant
"wanpipeline": WanT2V480PConfig, # Base Wan config as fallback for any Wan variant
"wanimagetovideo": WanI2V480PConfig,
"wandmdpipeline": FastWan2_1_T2V_480P_Config,
"wancausaldmdpipeline": SelfForcingWanT2V480PConfig,
"stepvideo": StepVideoT2VConfig,
"qwenimage": QwenImagePipelineConfig,
"qwenimageedit": QwenImageEditPipelineConfig,
# Other fallbacks by architecture
}
def get_pipeline_config_cls_from_name(
pipeline_name_or_path: str,
) -> type[PipelineConfig]:
"""Get the appropriate configuration class for a given pipeline name or path.
This function implements a multi-step lookup process to find the most suitable
configuration class for a given pipeline. It follows this order:
1. Exact match in the PIPE_NAME_TO_CONFIG
2. Partial match in the PIPE_NAME_TO_CONFIG
3. Fallback to class name in the model_index.json
    4. Otherwise, raise an error
Args:
pipeline_name_or_path (str): The name or path of the pipeline. This can be:
- A registered model ID (e.g., "FastVideo/FastHunyuan-diffusers")
- A local path to a model directory
- A model ID that will be downloaded
Returns:
Type[PipelineConfig]: The configuration class that best matches the pipeline.
        This will be either:
        - A specific weight configuration class if an exact or partial match is found
        - A fallback configuration class based on the detected pipeline architecture
        A ValueError is raised if no suitable configuration can be determined.
Note:
- For local paths, the function will verify the model configuration
- For remote models, it will attempt to download the model index
- Warning messages are logged when falling back to less specific configurations
"""
pipeline_config_cls: type[PipelineConfig] | None = None
# First try exact match for specific weights
if pipeline_name_or_path in PIPE_NAME_TO_CONFIG:
pipeline_config_cls = PIPE_NAME_TO_CONFIG[pipeline_name_or_path]
if pipeline_config_cls is None:
# Try partial matches (for local paths that might include the weight ID)
for registered_id, config_class in PIPE_NAME_TO_CONFIG.items():
if registered_id in pipeline_name_or_path:
pipeline_config_cls = config_class
break
# If no match, try to use the fallback config
if pipeline_config_cls is None:
if os.path.exists(pipeline_name_or_path):
config = verify_model_config_and_directory(pipeline_name_or_path)
else:
config = maybe_download_model_index(pipeline_name_or_path)
logger.warning(
"Trying to use the config from the model_index.json. sgl-diffusion may not correctly identify the optimal config for this model in this situation."
)
pipeline_name = config["_class_name"]
# Try to determine pipeline architecture for fallback
for pipeline_type, detector in PIPELINE_DETECTOR.items():
if detector(pipeline_name.lower()):
pipeline_config_cls = PIPELINE_FALLBACK_CONFIG.get(pipeline_type)
break
if pipeline_config_cls is not None:
logger.warning(
"No match found for pipeline %s, using fallback config %s.",
pipeline_name_or_path,
pipeline_config_cls,
)
if pipeline_config_cls is None:
raise ValueError(
f"No match found for pipeline {pipeline_name_or_path}, please check the pipeline name or path."
)
return pipeline_config_cls
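# Illustrative usage of the lookup order documented above (a sketch, not part of the
# registry's public examples):
#
#   config_cls = get_pipeline_config_cls_from_name("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
#   # -> WanT2V480PConfig, via an exact match in PIPE_NAME_TO_CONFIG
#   pipeline_config = config_cls()
#
# Unregistered names fall through to partial matching and then to the architecture
# detectors, and a ValueError is raised if nothing can be determined.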
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits import StepVideoConfig
from sglang.multimodal_gen.configs.models.vaes import StepVideoVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
@dataclass
class StepVideoT2VConfig(PipelineConfig):
"""Base configuration for StepVideo pipeline architecture."""
    # StepVideoT2VConfig-specific parameters with defaults
# DiT
dit_config: DiTConfig = field(default_factory=StepVideoConfig)
# VAE
vae_config: VAEConfig = field(default_factory=StepVideoVAEConfig)
vae_tiling: bool = False
vae_sp: bool = False
# Denoising stage
flow_shift: int = 13
timesteps_scale: bool = False
pos_magic: str = (
"超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。"
)
neg_magic: str = (
"画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。"
)
# Precision for each component
precision: str = "bf16"
vae_precision: str = "bf16"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Callable
from dataclasses import dataclass, field
import torch
from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig
from sglang.multimodal_gen.configs.models.dits import WanVideoConfig
from sglang.multimodal_gen.configs.models.encoders import (
BaseEncoderOutput,
CLIPVisionConfig,
T5Config,
)
from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig
from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig
def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor:
mask: torch.Tensor = outputs.attention_mask
hidden_state: torch.Tensor = outputs.last_hidden_state
seq_lens = mask.gt(0).sum(dim=1).long()
assert torch.isnan(hidden_state).sum() == 0
prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]
prompt_embeds_tensor: torch.Tensor = torch.stack(
[
torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))])
for u in prompt_embeds
],
dim=0,
)
return prompt_embeds_tensor
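# Note: each prompt embedding is truncated to its mask-derived length and then
# zero-padded back to a fixed 512 tokens, so the stacked result has shape
# (batch, 512, hidden_dim).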
@dataclass
class WanT2V480PConfig(PipelineConfig):
"""Base configuration for Wan T2V 1.3B pipeline architecture."""
# WanConfig-specific parameters with defaults
# DiT
dit_config: DiTConfig = field(default_factory=WanVideoConfig)
# VAE
vae_config: VAEConfig = field(default_factory=WanVAEConfig)
vae_tiling: bool = False
vae_sp: bool = False
# Denoising stage
flow_shift: float | None = 3.0
# Text encoding stage
text_encoder_configs: tuple[EncoderConfig, ...] = field(
default_factory=lambda: (T5Config(),)
)
postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] = (
field(default_factory=lambda: (t5_postprocess_text,))
)
# Precision for each component
precision: str = "bf16"
vae_precision: str = "fp32"
text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",))
# WanConfig-specific added parameters
def __post_init__(self):
self.vae_config.load_encoder = False
self.vae_config.load_decoder = True
@dataclass
class WanT2V720PConfig(WanT2V480PConfig):
"""Base configuration for Wan T2V 14B 720P pipeline architecture."""
# WanConfig-specific parameters with defaults
# Denoising stage
flow_shift: float | None = 5.0
@dataclass
class WanI2V480PConfig(WanT2V480PConfig):
"""Base configuration for Wan I2V 14B 480P pipeline architecture."""
# WanConfig-specific parameters with defaults
i2v_task: bool = True
# Precision for each component
image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig)
image_encoder_precision: str = "fp32"
image_encoder_extra_args: dict = field(
default_factory=lambda: dict(
output_hidden_states=True,
)
)
def postprocess_image(self, image):
return image.hidden_states[-2]
def __post_init__(self) -> None:
self.vae_config.load_encoder = True
self.vae_config.load_decoder = True
@dataclass
class WanI2V720PConfig(WanI2V480PConfig):
"""Base configuration for Wan I2V 14B 720P pipeline architecture."""
# WanConfig-specific parameters with defaults
# Denoising stage
flow_shift: float | None = 5.0
@dataclass
class FastWan2_1_T2V_480P_Config(WanT2V480PConfig):
"""Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD"""
# WanConfig-specific parameters with defaults
# Denoising stage
flow_shift: float | None = 8.0
dmd_denoising_steps: list[int] | None = field(
default_factory=lambda: [1000, 757, 522]
)
@dataclass
class Wan2_2_TI2V_5B_Config(WanT2V480PConfig):
flow_shift: float | None = 5.0
ti2v_task: bool = True
expand_timesteps: bool = True
# ti2v, 5B
vae_stride = (4, 16, 16)
def prepare_latent_shape(self, batch, batch_size, num_frames):
F = num_frames
z_dim = self.vae_config.arch_config.z_dim
vae_stride = self.vae_stride
oh = batch.height
ow = batch.width
shape = (z_dim, F, oh // vae_stride[1], ow // vae_stride[2])
return shape
def __post_init__(self) -> None:
self.vae_config.load_encoder = True
self.vae_config.load_decoder = True
self.dit_config.expand_timesteps = self.expand_timesteps
@dataclass
class FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config):
flow_shift: float | None = 5.0
dmd_denoising_steps: list[int] | None = field(
default_factory=lambda: [1000, 757, 522]
)
@dataclass
class Wan2_2_T2V_A14B_Config(WanT2V480PConfig):
flow_shift: float | None = 12.0
boundary_ratio: float | None = 0.875
def __post_init__(self) -> None:
self.dit_config.boundary_ratio = self.boundary_ratio
@dataclass
class Wan2_2_I2V_A14B_Config(WanI2V480PConfig):
flow_shift: float | None = 5.0
boundary_ratio: float | None = 0.900
def __post_init__(self) -> None:
super().__post_init__()
self.dit_config.boundary_ratio = self.boundary_ratio
# =============================================
# ============= Causal Self-Forcing =============
# =============================================
@dataclass
class SelfForcingWanT2V480PConfig(WanT2V480PConfig):
is_causal: bool = True
flow_shift: float | None = 5.0
dmd_denoising_steps: list[int] | None = field(
default_factory=lambda: [1000, 750, 500, 250]
)
warp_denoising_step: bool = True
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.configs.sample.base import SamplingParams
__all__ = ["SamplingParams"]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import argparse
import dataclasses
import hashlib
import json
import os.path
import re
import time
import unicodedata
import uuid
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import align_to
logger = init_logger(__name__)
def _json_safe(obj: Any):
"""
Recursively convert objects to JSON-serializable forms.
- Enums -> their name
- Sets/Tuples -> lists
- Dicts/Lists -> recursively processed
"""
if isinstance(obj, Enum):
return obj.name
if isinstance(obj, dict):
return {k: _json_safe(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple, set)):
return [_json_safe(v) for v in obj]
return obj
def generate_request_id() -> str:
return str(uuid.uuid4())
def _sanitize_filename(name: str, replacement: str = "_", max_length: int = 150) -> str:
"""Create a filesystem- and ffmpeg-friendly filename.
- Normalize to ASCII (drop accents and unsupported chars)
- Replace spaces with underscores
- Replace any char not in [A-Za-z0-9_.-] with replacement
- Collapse multiple underscores
- Trim leading/trailing dots/underscores and limit length
"""
normalized = unicodedata.normalize("NFKD", name)
ascii_name = normalized.encode("ascii", "ignore").decode("ascii")
ascii_name = ascii_name.replace(" ", "_")
ascii_name = re.sub(r"[^A-Za-z0-9._-]", replacement, ascii_name)
ascii_name = re.sub(r"_+", "_", ascii_name).strip("._")
if not ascii_name:
ascii_name = "output"
if max_length and len(ascii_name) > max_length:
ascii_name = ascii_name[:max_length]
return ascii_name
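# Example (illustrative): _sanitize_filename("A cat, running! 4K") -> "A_cat_running_4K"
# (spaces become underscores, disallowed characters are replaced, and repeated
# underscores are collapsed).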
class DataType(Enum):
IMAGE = auto()
VIDEO = auto()
def get_default_extension(self) -> str:
if self == DataType.IMAGE:
return "jpg"
else:
return "mp4"
@dataclass
class SamplingParams:
"""
Sampling parameters for generation.
"""
data_type: DataType = DataType.VIDEO
request_id: str | None = None
# All fields below are copied from ForwardBatch
# Image inputs
image_path: str | None = None
# Text inputs
prompt: str | list[str] | None = None
negative_prompt: str = (
"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
)
prompt_path: str | None = None
output_path: str = "outputs/"
output_file_name: str | None = None
# Batch info
num_outputs_per_prompt: int = 1
seed: int = 1024
# Original dimensions (before VAE scaling)
num_frames: int = 125
num_frames_round_down: bool = (
False # Whether to round down num_frames if it's not divisible by num_gpus
)
height: int | None = None
width: int | None = None
# NOTE: this is temporary, we need a way to know if width or height is not provided, or do the image resize earlier
height_not_provided: bool = False
width_not_provided: bool = False
fps: int = 24
# Denoising parameters
num_inference_steps: int = 50
guidance_scale: float = 1.0
guidance_rescale: float = 0.0
boundary_ratio: float | None = None
# TeaCache parameters
enable_teacache: bool = False
# Profiling
profile: bool = False
num_profiled_timesteps: int = 2
# Debugging
debug: bool = False
# Misc
save_output: bool = True
return_frames: bool = False
return_trajectory_latents: bool = False # returns all latents for each timestep
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
def set_output_file_ext(self):
# add extension if needed
if not any(
self.output_file_name.endswith(ext)
for ext in [".mp4", ".jpg", ".png", ".webp"]
):
self.output_file_name = (
f"{self.output_file_name}.{self.data_type.get_default_extension()}"
)
def set_output_file_name(self):
# settle output_file_name
if (
self.output_file_name is None
and self.prompt
and isinstance(self.prompt, str)
):
# generate a random filename
# get a hash of current params
params_dict = dataclasses.asdict(self)
# Avoid recursion
params_dict["output_file_name"] = ""
# Convert to a stable JSON string
params_str = json.dumps(_json_safe(params_dict), sort_keys=True)
# Create a hash
hasher = hashlib.sha256()
hasher.update(params_str.encode("utf-8"))
param_hash = hasher.hexdigest()[:8]
timestamp = time.strftime("%Y%m%d-%H%M%S")
base = f"{self.prompt[:100]}_{timestamp}_{param_hash}"
self.output_file_name = base
if self.output_file_name is None:
timestamp = time.strftime("%Y%m%d-%H%M%S")
self.output_file_name = f"output_{timestamp}"
self.output_file_name = _sanitize_filename(self.output_file_name)
# Ensure a proper extension is present
self.set_output_file_ext()
def __post_init__(self) -> None:
assert self.num_frames >= 1
self.data_type = DataType.VIDEO if self.num_frames > 1 else DataType.IMAGE
if self.width is None:
self.width_not_provided = True
self.width = 1280
if self.height is None:
self.height_not_provided = True
self.height = 720
def check_sampling_param(self):
if self.prompt_path and not self.prompt_path.endswith(".txt"):
raise ValueError("prompt_path must be a txt file")
def update(self, source_dict: dict[str, Any]) -> None:
for key, value in source_dict.items():
if hasattr(self, key):
setattr(self, key, value)
else:
                logger.warning("%s has no attribute %s", type(self).__name__, key)
self.__post_init__()
@classmethod
def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams":
from sglang.multimodal_gen.configs.sample.registry import (
get_sampling_param_cls_for_name,
)
sampling_cls = get_sampling_param_cls_for_name(model_path)
logger.debug(f"Using pretrained SamplingParam: {sampling_cls}")
if sampling_cls is not None:
sampling_params: SamplingParams = sampling_cls(**kwargs)
else:
logger.warning(
"Couldn't find an optimal sampling param for %s. Using the default sampling param.",
model_path,
)
sampling_params = cls(**kwargs)
return sampling_params
def from_user_sampling_params(self, user_params):
sampling_params = deepcopy(self)
sampling_params._merge_with_user_params(user_params)
return sampling_params
@staticmethod
def add_cli_args(parser: Any) -> Any:
"""Add CLI arguments for SamplingParam fields"""
parser.add_argument("--data-type", type=str, nargs="+", default=DataType.VIDEO)
parser.add_argument(
"--num-frames-round-down",
action="store_true",
default=SamplingParams.num_frames_round_down,
)
parser.add_argument(
"--enable-teacache",
action="store_true",
default=SamplingParams.enable_teacache,
)
parser.add_argument(
"--profile",
action="store_true",
default=SamplingParams.profile,
help="Enable torch profiler for denoising stage",
)
parser.add_argument(
"--debug",
action="store_true",
default=SamplingParams.debug,
help="",
)
parser.add_argument(
"--num-profiled-timesteps",
type=int,
default=SamplingParams.num_profiled_timesteps,
help="Number of timesteps to profile after warmup",
)
parser.add_argument(
"--prompt",
type=str,
default=SamplingParams.prompt,
help="Text prompt for generation",
)
parser.add_argument(
"--negative-prompt",
type=str,
default=SamplingParams.negative_prompt,
help="Negative text prompt for generation",
)
parser.add_argument(
"--prompt-path",
type=str,
default=SamplingParams.prompt_path,
help="Path to a text file containing the prompt",
)
parser.add_argument(
"--output-path",
type=str,
default=SamplingParams.output_path,
help="Path to save the generated image/video",
)
parser.add_argument(
"--output-file-name",
type=str,
default=SamplingParams.output_file_name,
help="Name of the output file",
)
parser.add_argument(
"--num-outputs-per-prompt",
type=int,
default=SamplingParams.num_outputs_per_prompt,
help="Number of outputs to generate per prompt",
)
parser.add_argument(
"--seed",
type=int,
default=SamplingParams.seed,
help="Random seed for generation",
)
parser.add_argument(
"--num-frames",
type=int,
default=SamplingParams.num_frames,
help="Number of frames to generate",
)
parser.add_argument(
"--height",
type=int,
default=SamplingParams.height,
help="Height of generated output",
)
parser.add_argument(
"--width",
type=int,
default=SamplingParams.width,
help="Width of generated output",
)
parser.add_argument(
"--fps",
type=int,
default=SamplingParams.fps,
help="Frames per second for saved output",
)
parser.add_argument(
"--num-inference-steps",
type=int,
default=SamplingParams.num_inference_steps,
help="Number of denoising steps",
)
parser.add_argument(
"--guidance-scale",
type=float,
default=SamplingParams.guidance_scale,
help="Classifier-free guidance scale",
)
parser.add_argument(
"--guidance-rescale",
type=float,
default=SamplingParams.guidance_rescale,
help="Guidance rescale factor",
)
parser.add_argument(
"--boundary-ratio",
type=float,
default=SamplingParams.boundary_ratio,
help="Boundary timestep ratio",
)
parser.add_argument(
"--save-output",
action="store_true",
default=SamplingParams.save_output,
help="Whether to save the output to disk",
)
parser.add_argument(
"--no-save-output",
action="store_false",
dest="save_output",
help="Don't save the output to disk",
)
parser.add_argument(
"--return-frames",
action="store_true",
default=SamplingParams.return_frames,
help="Whether to return the raw frames",
)
parser.add_argument(
"--image-path",
type=str,
default=SamplingParams.image_path,
help="Path to input image for image-to-video generation",
)
parser.add_argument(
"--moba-config-path",
type=str,
default=None,
help="Path to a JSON file containing V-MoBA specific configurations.",
)
parser.add_argument(
"--return-trajectory-latents",
action="store_true",
default=SamplingParams.return_trajectory_latents,
help="Whether to return the trajectory",
)
parser.add_argument(
"--return-trajectory-decoded",
action="store_true",
default=SamplingParams.return_trajectory_decoded,
help="Whether to return the decoded trajectory",
)
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
attrs = [attr.name for attr in dataclasses.fields(cls)]
args.height_not_provided = False
args.width_not_provided = False
return cls(**{attr: getattr(args, attr) for attr in attrs})
def output_file_path(self):
return os.path.join(self.output_path, self.output_file_name)
def _merge_with_user_params(self, user_params):
"""
Merges parameters from a user-provided SamplingParams object.
This method updates the current object with values from `user_params`,
but skips any fields that are explicitly defined in the current object's
subclass. This is to preserve model-specific optimal parameters.
It also skips fields that the user has not changed from the default
in `user_params`.
"""
if user_params is None:
return
# Get fields defined directly in the subclass (not inherited)
subclass_defined_fields = set(type(self).__annotations__.keys())
        # Compare each field against a default SamplingParams instance to detect user-modified fields
default_params = SamplingParams()
for field in dataclasses.fields(user_params):
field_name = field.name
user_value = getattr(user_params, field_name)
default_value = getattr(default_params, field_name)
# A field is considered user-modified if its value is different from
# the default, with an exception for `output_file_name` which is
# auto-generated with a random component.
is_user_modified = (
user_value != default_value
if field_name != "output_file_name"
                else user_params.output_file_name is not None
)
if is_user_modified and field_name not in subclass_defined_fields:
if hasattr(self, field_name):
setattr(self, field_name, user_value)
self.__post_init__()
@property
def n_tokens(self) -> int:
# Calculate latent sizes
if self.height and self.width:
latents_size = [
(self.num_frames - 1) // 4 + 1,
self.height // 8,
self.width // 8,
]
n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
else:
n_tokens = -1
return n_tokens
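    # Example (illustrative): with num_frames=125 and a 1280x720 output,
    # latents_size is [32, 90, 160] ((125 - 1) // 4 + 1 = 32, 720 // 8 = 90,
    # 1280 // 8 = 160), so n_tokens = 32 * 90 * 160 = 460800.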
def log(self, server_args: ServerArgs):
        # TODO: in some cases (e.g., TI2I), height and width might be undecided at this moment
if self.height:
target_height = align_to(self.height, 16)
else:
target_height = -1
if self.width:
target_width = align_to(self.width, 16)
else:
target_width = -1
# Log sampling parameters
debug_str = f"""Sampling params:
height: {target_height}
width: {target_width}
num_frames: {self.num_frames}
prompt: {self.prompt}
neg_prompt: {self.negative_prompt}
seed: {self.seed}
infer_steps: {self.num_inference_steps}
num_outputs_per_prompt: {self.num_outputs_per_prompt}
guidance_scale: {self.guidance_scale}
embedded_guidance_scale: {server_args.pipeline_config.embedded_cfg_scale}
n_tokens: {self.n_tokens}
flow_shift: {server_args.pipeline_config.flow_shift}
image_path: {self.image_path}
save_output: {self.save_output}
output_file_path: {self.output_file_path()}
""" # type: ignore[attr-defined]
logger.info(debug_str)
@dataclass
class CacheParams:
cache_type: str = "none"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from sglang.multimodal_gen.configs.sample.base import SamplingParams
@dataclass
class FluxSamplingParams(SamplingParams):
# Video parameters
# height: int = 1024
# width: int = 1024
num_frames: int = 1
# Denoising stage
guidance_scale: float = 1.0
    negative_prompt: str | None = None
num_inference_steps: int = 50
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.sample.base import SamplingParams
from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams
@dataclass
class HunyuanSamplingParams(SamplingParams):
num_inference_steps: int = 50
num_frames: int = 125
height: int = 720
width: int = 1280
fps: int = 24
guidance_scale: float = 1.0
teacache_params: TeaCacheParams = field(
default_factory=lambda: TeaCacheParams(
teacache_thresh=0.15,
coefficients=[
7.33226126e02,
-4.01131952e02,
6.75869174e01,
-3.14987800e00,
9.61237896e-02,
],
)
)
@dataclass
class FastHunyuanSamplingParam(HunyuanSamplingParams):
num_inference_steps: int = 6
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from sglang.multimodal_gen.configs.sample.base import SamplingParams
@dataclass
class QwenImageSamplingParams(SamplingParams):
# Video parameters
# height: int = 1024
# width: int = 1024
negative_prompt: str = " "
num_frames: int = 1
# Denoising stage
guidance_scale: float = 4.0
num_inference_steps: int = 50
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import os
from collections.abc import Callable
from typing import Any
from sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams
from sglang.multimodal_gen.configs.sample.hunyuan import (
FastHunyuanSamplingParam,
HunyuanSamplingParams,
)
from sglang.multimodal_gen.configs.sample.qwenimage import QwenImageSamplingParams
from sglang.multimodal_gen.configs.sample.stepvideo import StepVideoT2VSamplingParams
# isort: off
from sglang.multimodal_gen.configs.sample.wan import (
FastWanT2V480PConfig,
Wan2_1_Fun_1_3B_InP_SamplingParams,
Wan2_2_I2V_A14B_SamplingParam,
Wan2_2_T2V_A14B_SamplingParam,
Wan2_2_TI2V_5B_SamplingParam,
WanI2V_14B_480P_SamplingParam,
WanI2V_14B_720P_SamplingParam,
WanT2V_1_3B_SamplingParams,
WanT2V_14B_SamplingParams,
SelfForcingWanT2V480PConfig,
)
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
maybe_download_model_index,
verify_model_config_and_directory,
)
# isort: on
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
# Registry maps specific model weights to their config classes
SAMPLING_PARAM_REGISTRY: dict[str, Any] = {
"FastVideo/FastHunyuan-diffusers": FastHunyuanSamplingParam,
"hunyuanvideo-community/HunyuanVideo": HunyuanSamplingParams,
"FastVideo/stepvideo-t2v-diffusers": StepVideoT2VSamplingParams,
# Wan2.1
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V_1_3B_SamplingParams,
"Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V_14B_SamplingParams,
"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V_14B_480P_SamplingParam,
"Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V_14B_720P_SamplingParam,
"weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers": Wan2_1_Fun_1_3B_InP_SamplingParams,
# Wan2.2
"Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam,
"FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers": Wan2_2_TI2V_5B_SamplingParam,
"Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_SamplingParam,
"Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_SamplingParam,
# FastWan2.1
"FastVideo/FastWan2.1-T2V-1.3B-Diffusers": FastWanT2V480PConfig,
# FastWan2.2
"FastVideo/FastWan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam,
# Causal Self-Forcing Wan2.1
"wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers": SelfForcingWanT2V480PConfig,
# Add other specific weight variants
"black-forest-labs/FLUX.1-dev": FluxSamplingParams,
"Qwen/Qwen-Image": QwenImageSamplingParams,
"Qwen/Qwen-Image-Edit": QwenImageSamplingParams,
}
# For determining pipeline type from model ID
SAMPLING_PARAM_DETECTOR: dict[str, Callable[[str], bool]] = {
"hunyuan": lambda id: "hunyuan" in id.lower(),
"wanpipeline": lambda id: "wanpipeline" in id.lower(),
"wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(),
"stepvideo": lambda id: "stepvideo" in id.lower(),
# Add other pipeline architecture detectors
"flux": lambda id: "flux" in id.lower(),
}
# Fallback configs when exact match isn't found but architecture is detected
SAMPLING_FALLBACK_PARAM: dict[str, Any] = {
"hunyuan": HunyuanSamplingParams, # Base Hunyuan config as fallback for any Hunyuan variant
"wanpipeline": WanT2V_1_3B_SamplingParams, # Base Wan config as fallback for any Wan variant
"wanimagetovideo": WanI2V_14B_480P_SamplingParam,
"stepvideo": StepVideoT2VSamplingParams,
# Other fallbacks by architecture
"flux": FluxSamplingParams,
}
def get_sampling_param_cls_for_name(pipeline_name_or_path: str) -> Any | None:
"""Get the appropriate sampling param for specific pretrained weights."""
if os.path.exists(pipeline_name_or_path):
config = verify_model_config_and_directory(pipeline_name_or_path)
logger.warning(
"sgl-diffusion may not correctly identify the optimal sampling param for this model, as the local directory may have been renamed."
)
else:
config = maybe_download_model_index(pipeline_name_or_path)
pipeline_name = config["_class_name"]
# First try exact match for specific weights
if pipeline_name_or_path in SAMPLING_PARAM_REGISTRY:
return SAMPLING_PARAM_REGISTRY[pipeline_name_or_path]
# Try partial matches (for local paths that might include the weight ID)
for registered_id, config_class in SAMPLING_PARAM_REGISTRY.items():
if registered_id in pipeline_name_or_path:
return config_class
# If no match, try to use the fallback config
fallback_config = None
# Try to determine pipeline architecture for fallback
for pipeline_type, detector in SAMPLING_PARAM_DETECTOR.items():
if detector(pipeline_name.lower()):
fallback_config = SAMPLING_FALLBACK_PARAM.get(pipeline_type)
break
logger.warning(
"No match found for pipeline %s, using fallback sampling param %s.",
pipeline_name_or_path,
fallback_config,
)
return fallback_config
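# Illustrative precedence (a sketch): "Qwen/Qwen-Image" resolves to
# QwenImageSamplingParams via an exact SAMPLING_PARAM_REGISTRY match, while an
# unregistered HunyuanVideo checkpoint falls back to HunyuanSamplingParams through
# the "hunyuan" detector in SAMPLING_PARAM_DETECTOR / SAMPLING_FALLBACK_PARAM.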
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from sglang.multimodal_gen.configs.sample.base import SamplingParams
@dataclass
class StepVideoT2VSamplingParams(SamplingParams):
# Video parameters
height: int = 720
width: int = 1280
num_frames: int = 81
# Denoising stage
guidance_scale: float = 9.0
num_inference_steps: int = 50
# neg magic and pos magic
# pos_magic: str = "超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。"
# neg_magic: str = "画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。"
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.sample.base import CacheParams
@dataclass
class TeaCacheParams(CacheParams):
cache_type: str = "teacache"
teacache_thresh: float = 0.0
coefficients: list[float] = field(default_factory=list)
@dataclass
class WanTeaCacheParams(CacheParams):
    # Unfortunately, TeaCache for Wan differs significantly from other models
cache_type: str = "teacache"
teacache_thresh: float = 0.0
use_ret_steps: bool = True
ret_steps_coeffs: list[float] = field(default_factory=list)
non_ret_steps_coeffs: list[float] = field(default_factory=list)
@property
def coefficients(self) -> list[float]:
if self.use_ret_steps:
return self.ret_steps_coeffs
else:
return self.non_ret_steps_coeffs
@property
def ret_steps(self) -> int:
if self.use_ret_steps:
return 5 * 2
else:
return 1 * 2
def get_cutoff_steps(self, num_inference_steps: int) -> int:
if self.use_ret_steps:
return num_inference_steps * 2
else:
return num_inference_steps * 2 - 2
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.sample.base import SamplingParams
from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams
@dataclass
class WanT2V_1_3B_SamplingParams(SamplingParams):
# Video parameters
height: int = 480
width: int = 832
num_frames: int = 81
fps: int = 16
# Denoising stage
guidance_scale: float = 3.0
negative_prompt: str = (
"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
)
num_inference_steps: int = 50
teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_thresh=0.08,
ret_steps_coeffs=[
-5.21862437e04,
9.23041404e03,
-5.28275948e02,
1.36987616e01,
-4.99875664e-02,
],
non_ret_steps_coeffs=[
2.39676752e03,
-1.31110545e03,
2.01331979e02,
-8.29855975e00,
1.37887774e-01,
],
)
)
@dataclass
class WanT2V_14B_SamplingParams(SamplingParams):
# Video parameters
height: int = 720
width: int = 1280
num_frames: int = 81
fps: int = 16
# Denoising stage
guidance_scale: float = 5.0
negative_prompt: str = (
"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
)
num_inference_steps: int = 50
teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_thresh=0.20,
use_ret_steps=False,
ret_steps_coeffs=[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
non_ret_steps_coeffs=[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
)
)
@dataclass
class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams):
# Denoising stage
guidance_scale: float = 5.0
num_inference_steps: int = 50
# num_inference_steps: int = 40
teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_thresh=0.26,
ret_steps_coeffs=[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
non_ret_steps_coeffs=[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
)
)
@dataclass
class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams):
# Denoising stage
guidance_scale: float = 5.0
num_inference_steps: int = 50
# num_inference_steps: int = 40
teacache_params: WanTeaCacheParams = field(
default_factory=lambda: WanTeaCacheParams(
teacache_thresh=0.3,
ret_steps_coeffs=[
-3.03318725e05,
4.90537029e04,
-2.65530556e03,
5.87365115e01,
-3.15583525e-01,
],
non_ret_steps_coeffs=[
-5784.54975374,
5449.50911966,
-1811.16591783,
256.27178429,
-13.02252404,
],
)
)
@dataclass
class FastWanT2V480PConfig(WanT2V_1_3B_SamplingParams):
# DMD parameters
# dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
num_inference_steps: int = 3
num_frames: int = 61
height: int = 448
width: int = 832
fps: int = 16
# =============================================
# ============= Wan2.1 Fun Models =============
# =============================================
@dataclass
class Wan2_1_Fun_1_3B_InP_SamplingParams(SamplingParams):
"""Sampling parameters for Wan2.1 Fun 1.3B InP model."""
height: int = 480
width: int = 832
num_frames: int = 81
fps: int = 16
negative_prompt: str | None = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
)
guidance_scale: float = 6.0
num_inference_steps: int = 50
# =============================================
# ============= Wan2.2 TI2V Models =============
# =============================================
@dataclass
class Wan2_2_Base_SamplingParams(SamplingParams):
"""Sampling parameters for Wan2.2 TI2V 5B model."""
negative_prompt: str | None = (
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
)
@dataclass
class Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParams):
"""Sampling parameters for Wan2.2 TI2V 5B model."""
height: int = 704
width: int = 1280
num_frames: int = 121
fps: int = 24
guidance_scale: float = 5.0
num_inference_steps: int = 50
@dataclass
class Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams):
guidance_scale: float = 4.0 # high_noise
guidance_scale_2: float = 3.0 # low_noise
num_inference_steps: int = 40
fps: int = 16
# NOTE(will): default boundary timestep is tracked by PipelineConfig, but
# can be overridden during sampling
@dataclass
class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams):
guidance_scale: float = 3.5 # high_noise
guidance_scale_2: float = 3.5 # low_noise
num_inference_steps: int = 40
fps: int = 16
# NOTE(will): default boundary timestep is tracked by PipelineConfig, but
# can be overridden during sampling
# =============================================
# ============= Causal Self-Forcing =============
# =============================================
@dataclass
class SelfForcingWanT2V480PConfig(WanT2V_1_3B_SamplingParams):
pass
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import argparse
from typing import Any
def update_config_from_args(
config: Any, args_dict: dict[str, Any], prefix: str = "", pop_args: bool = False
) -> bool:
"""
Update configuration object from arguments dictionary.
Args:
config: The configuration object to update
args_dict: Dictionary containing arguments
prefix: Prefix for the configuration parameters in the args_dict.
            If empty, assumes direct attribute mapping without a prefix.
"""
# Handle top-level attributes (no prefix)
args_not_to_remove = [
"model_path",
]
args_to_remove = []
if prefix.strip() == "":
for key, value in args_dict.items():
if hasattr(config, key) and value is not None:
if key == "text_encoder_precisions" and isinstance(value, list):
setattr(config, key, tuple(value))
else:
setattr(config, key, value)
if pop_args:
args_to_remove.append(key)
else:
# Handle nested attributes with prefix
prefix_with_dot = f"{prefix}."
for key, value in args_dict.items():
if key.startswith(prefix_with_dot) and value is not None:
attr_name = key[len(prefix_with_dot) :]
if hasattr(config, attr_name):
setattr(config, attr_name, value)
if pop_args:
args_to_remove.append(key)
if pop_args:
for key in args_to_remove:
if key not in args_not_to_remove:
args_dict.pop(key)
return len(args_to_remove) > 0
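# Example (illustrative, hypothetical args_dict): with
#   args_dict = {"flow_shift": 5.0, "vae_config.load_encoder": True}
# update_config_from_args(pipeline_config, args_dict) applies flow_shift directly
# (no prefix), while update_config_from_args(pipeline_config.vae_config, args_dict,
# prefix="vae_config") strips the "vae_config." prefix and sets load_encoder.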
def clean_cli_args(args: argparse.Namespace) -> dict[str, Any]:
"""
    Clean the arguments by removing the ones that were not explicitly provided by the user.
"""
provided_args = {}
for k, v in vars(args).items():
if v is not None and hasattr(args, "_provided") and k in args._provided:
provided_args[k] = v
return provided_args
{
"embedded_cfg_scale": 6.0,
"flow_shift": 3,
"dit_cpu_offload": true,
"disable_autocast": false,
"precision": "bf16",
"vae_precision": "fp32",
"vae_tiling": false,
"vae_sp": false,
"vae_config": {
"load_encoder": false,
"load_decoder": true,
"tile_sample_min_height": 256,
"tile_sample_min_width": 256,
"tile_sample_min_num_frames": 16,
"tile_sample_stride_height": 192,
"tile_sample_stride_width": 192,
"tile_sample_stride_num_frames": 12,
"blend_num_frames": 8,
"use_tiling": false,
"use_temporal_tiling": false,
"use_parallel_tiling": false,
"use_feature_cache": true
},
"dit_config": {
"prefix": "Wan",
"quant_config": null
},
"text_encoder_precisions": [
"fp32"
],
"text_encoder_configs": [
{
"prefix": "t5",
"quant_config": null,
"lora_config": null
}
],
"mask_strategy_file_path": null,
"enable_torch_compile": false
}
{
"embedded_cfg_scale": 6.0,
"flow_shift": 3,
"dit_cpu_offload": true,
"disable_autocast": false,
"precision": "bf16",
"vae_precision": "fp32",
"vae_tiling": false,
"vae_sp": false,
"vae_config": {
"load_encoder": true,
"load_decoder": true,
"tile_sample_min_height": 256,
"tile_sample_min_width": 256,
"tile_sample_min_num_frames": 16,
"tile_sample_stride_height": 192,
"tile_sample_stride_width": 192,
"tile_sample_stride_num_frames": 12,
"blend_num_frames": 8,
"use_tiling": false,
"use_temporal_tiling": false,
"use_parallel_tiling": false,
"use_feature_cache": true
},
"dit_config": {
"prefix": "Wan",
"quant_config": null
},
"text_encoder_precisions": [
"fp32"
],
"text_encoder_configs": [
{
"prefix": "t5",
"quant_config": null,
"lora_config": null
}
],
"mask_strategy_file_path": null,
"enable_torch_compile": false,
"image_encoder_config": {
"prefix": "clip",
"quant_config": null,
"lora_config": null,
"num_hidden_layers_override": null,
"require_post_norm": null
},
"image_encoder_precision": "fp32"
}
# Attention Kernel Used in sgl-diffusion
## VMoBA: Mixture-of-Block Attention for Video Diffusion Models
### Installation
Please ensure that you have installed FlashAttention version **2.7.1 or higher**, as some interfaces have changed in recent releases.
### Usage
You can use `moba_attn_varlen` in the following ways:
**Install from source:**
```bash
python setup.py install
```
**Import after installation:**
```python
from vmoba import moba_attn_varlen
```
**Or import directly from the project root:**
```python
from csrc.attn.vmoba_attn.vmoba import moba_attn_varlen
```
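The exact signature of `moba_attn_varlen` depends on the VMoBA revision shipped in this repo, so the snippet below is only a hedged sketch: it prepares packed variable-length inputs in the FlashAttention-style varlen layout that such kernels typically consume. The argument names in the commented-out call are placeholders, not the real API; check `csrc/attn/vmoba_attn/vmoba/vmoba.py` for the actual interface.

```python
# Hypothetical sketch only: the moba_attn_varlen argument names are placeholders.
import torch

from csrc.attn.vmoba_attn.vmoba import moba_attn_varlen  # noqa: F401

num_heads, head_dim = 16, 64
seq_lens = [1024, 2048]  # two sequences packed into one varlen batch
total_tokens = sum(seq_lens)

# Cumulative sequence lengths, as used by FlashAttention-style varlen kernels.
cu_seqlens = torch.tensor([0, 1024, 3072], dtype=torch.int32, device="cuda")

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# out = moba_attn_varlen(q, k, v, cu_seqlens, max_seqlen=max(seq_lens), ...)  # placeholder
```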
### Verify the installation
```bash
python csrc/attn/vmoba_attn/vmoba/vmoba.py
```