"vscode:/vscode.git/clone" did not exist on "8b6966d0205abeaca143693c6f273dcacbfa779d"
Unverified Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
default_stages: [pre-commit, pre-push, manual]
exclude: ^python/sglang/multimodal_gen/csrc
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
@@ -31,7 +32,15 @@ repos:
- --select=F401,F821
- --fix
files: ^(benchmark/|docs/|examples/|python/sglang/|sgl-router/py_*)
exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
exclude: |
(?x)^(
.*/__init__\.py$|
.*\.ipynb$|
python/sglang/srt/grpc/.*_pb2\.py$|
python/sglang/srt/grpc/.*_pb2_grpc\.py$|
python/sglang/srt/grpc/.*_pb2\.pyi$|
      python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
)$
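    # Illustrative paths skipped by this exclude (examples are assumptions, for clarity):
    #   python/sglang/__init__.py, docs/notebooks/demo.ipynb,
    #   python/sglang/srt/grpc/foo_pb2.py, python/sglang/srt/grpc/foo_pb2_grpc.pyi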
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
......
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
SHELL ["/bin/bash", "-c"]
WORKDIR /sgl-workspace/sglang
RUN apt-get update && apt-get install -y --no-install-recommends \
wget \
git \
ca-certificates \
openssh-server \
zsh \
vim \
curl \
gcc-11 \
g++-11 \
clang-11 \
libnuma1 libnuma-dev \
&& rm -rf /var/lib/apt/lists/*
# Install oh-my-zsh and plugins
RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended \
&& git clone https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \
&& git clone https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting
# Set up C++20 compilers for ThunderKittens
RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 --slave /usr/bin/g++ g++ /usr/bin/g++-11
# Set CUDA environment variables
ENV CUDA_HOME=/usr/local/cuda-12.8
ENV PATH=${CUDA_HOME}/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH
# Install uv and source its environment
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
echo 'source $HOME/.local/bin/env' >> /root/.zshrc
# Copy just the pyproject.toml first to leverage Docker cache
COPY python/pyproject.toml python/
# Create a dummy README to satisfy the installation
RUN mkdir -p python && echo "# Placeholder" > python/README.md
# Create and activate virtual environment with specific Python version and seed
RUN source $HOME/.local/bin/env && \
uv venv --python 3.12 --seed /opt/venv && \
source /opt/venv/bin/activate && \
uv pip install nvitop && \
uv pip install --no-cache-dir --upgrade pip && \
    uv pip install --no-cache-dir --prerelease=allow "./python[diffusion]"
COPY . .
# Install dependencies using uv and set up shell configuration
RUN source $HOME/.local/bin/env && \
source /opt/venv/bin/activate && \
git config --unset-all http.https://github.com/.extraheader || true && \
echo 'source /opt/venv/bin/activate' >> /root/.zshrc && \
echo 'if [ -n "$ZSH_VERSION" ] && [ -f ~/.zshrc ]; then . ~/.zshrc; elif [ -f ~/.bashrc ]; then . ~/.bashrc; fi' > /root/.profile
# Set PATH to include venv bin
ENV PATH=/opt/venv/bin:$PATH
# Configure zsh
COPY --chown=root:root <<-"EOF" /root/.zshrc
export ZSH="/root/.oh-my-zsh"
source $HOME/.local/bin/env
source /opt/venv/bin/activate
## Theme
ZSH_THEME="robbyrussell"
## Plugins
plugins=(
git
z
zsh-autosuggestions
zsh-syntax-highlighting
)
source $ZSH/oh-my-zsh.sh
## Aliases
alias ll='ls -alF'
alias la='ls -A'
alias l='ls -CF'
alias vi='vim'
## Enhanced history
HISTSIZE=10000
SAVEHIST=10000
setopt HIST_IGNORE_ALL_DUPS
setopt HIST_FIND_NO_DUPS
setopt INC_APPEND_HISTORY
EOF
EXPOSE 22
CMD ["/bin/zsh"]
@@ -79,6 +79,25 @@ dependencies = [
[project.optional-dependencies]
checkpoint-engine = ["checkpoint-engine==0.1.2"]
diffusion = [
"diffusers==0.35.2",
"yunchang==0.6.3.post1",
"opencv-python==4.10.0.84",
"imageio==2.36.0",
"imageio-ffmpeg==0.5.1",
"PyYAML==6.0.1",
"moviepy>=2.0.0",
"cloudpickle",
"remote-pdb",
"torchcodec==0.5.0",
"st_attn ==0.0.7",
"vsa==0.0.4",
]
[tool.uv.extra-build-dependencies]
st-attn = ["torch", "setuptools"]
vsa = ["torch", "setuptools"]
test = [
"accelerate",
"expecttest",
@@ -102,6 +121,9 @@ tracing = [
"Homepage" = "https://github.com/sgl-project/sglang"
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[project.scripts]
sglang = "sglang.cli.main:main"
[tool.setuptools.package-data]
"sglang" = [
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
......
import argparse
from sglang.cli.main import get_is_diffusion_model, get_model_path
from sglang.multimodal_gen.runtime.entrypoints.cli.generate import (
add_multimodal_gen_generate_args,
generate_cmd,
)
def generate(args, extra_argv):
model_path = get_model_path(extra_argv)
is_diffusion_model = get_is_diffusion_model(model_path)
if is_diffusion_model:
parser = argparse.ArgumentParser(description="SGLang Multimodal Generation")
add_multimodal_gen_generate_args(parser)
parsed_args = parser.parse_args(extra_argv)
generate_cmd(parsed_args)
else:
raise Exception(
f"Generate subcommand is not yet supported for model: {model_path}"
)
import argparse
import hashlib
import json
import logging
import os
import tempfile
from typing import Optional
import filelock
from huggingface_hub import hf_hub_download
from sglang.cli.generate import generate
from sglang.cli.serve import serve
logger = logging.getLogger(__name__)
temp_dir = tempfile.gettempdir()
def _get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
lock_dir = cache_dir or temp_dir
    os.makedirs(lock_dir, exist_ok=True)
model_name = model_name_or_path.replace("/", "-")
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
lock_file_name = hash_name + model_name + ".lock"
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
return lock
# Copied and adapted from hf_diffusers_utils.py
def _maybe_download_model(
model_name_or_path: str, local_dir: str | None = None, download: bool = True
) -> str:
"""
Resolve a model path. If it's a local directory, return it.
If it's a Hugging Face Hub ID, download only the config file
(`model_index.json` or `config.json`) and return its directory.
Args:
model_name_or_path: Local path or Hugging Face Hub model ID
local_dir: Local directory to save the downloaded file (if any)
download: Whether to download from Hugging Face Hub when needed
Returns:
Local directory path that contains the downloaded config file, or the original local directory.
"""
if os.path.exists(model_name_or_path):
logger.info("Model already exists locally")
return model_name_or_path
if not download:
return model_name_or_path
with _get_lock(model_name_or_path):
# Try `model_index.json` first (diffusers models)
try:
logger.info(
"Downloading model_index.json from HF Hub for %s...",
model_name_or_path,
)
file_path = hf_hub_download(
repo_id=model_name_or_path,
filename="model_index.json",
local_dir=local_dir,
)
logger.info("Downloaded to %s", file_path)
return os.path.dirname(file_path)
except Exception as e_index:
logger.debug("model_index.json not found or failed: %s", e_index)
# Fallback to `config.json`
try:
logger.info(
"Downloading config.json from HF Hub for %s...", model_name_or_path
)
file_path = hf_hub_download(
repo_id=model_name_or_path,
filename="config.json",
local_dir=local_dir,
)
logger.info("Downloaded to %s", file_path)
return os.path.dirname(file_path)
except Exception as e_config:
raise ValueError(
(
"Could not find model locally at %s and failed to download "
"model_index.json/config.json from HF Hub: %s"
)
% (model_name_or_path, e_config)
) from e_config
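# Illustrative usage (assumes network access and a valid Hub repo id):
#   local_dir = _maybe_download_model("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
#   # -> local directory containing the downloaded model_index.json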
# Copied and adapted from hf_diffusers_utils.py
def is_diffusers_model_path(model_path: str) -> bool:
    """
    Verify whether the model directory contains a valid diffusers configuration.
    Args:
        model_path: Path to the model directory
    Returns:
        True if the directory contains a model_index.json with a
        `_diffusers_version` key, False otherwise.
    """
# Prefer model_index.json which indicates a diffusers pipeline
config_path = os.path.join(model_path, "model_index.json")
if not os.path.exists(config_path):
return False
# Load the config
with open(config_path) as f:
config = json.load(f)
# Verify diffusers version exists
if "_diffusers_version" not in config:
return False
return True
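# For reference, a diffusers pipeline's model_index.json looks roughly like
# (field values below are illustrative):
#   {"_class_name": "WanPipeline", "_diffusers_version": "0.35.2", ...}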
def get_is_diffusion_model(model_path: str) -> bool:
model_path = _maybe_download_model(model_path)
is_diffusion_model = is_diffusers_model_path(model_path)
if is_diffusion_model:
logger.info("Diffusion model detected")
return is_diffusion_model
def get_model_path(extra_argv):
# Find the model_path argument
model_path = None
for i, arg in enumerate(extra_argv):
if arg == "--model-path":
if i + 1 < len(extra_argv):
model_path = extra_argv[i + 1]
break
elif arg.startswith("--model-path="):
model_path = arg.split("=", 1)[1]
break
if model_path is None:
# Fallback for --help or other cases where model-path is not provided
if any(h in extra_argv for h in ["-h", "--help"]):
raise Exception(
"Usage: sglang serve --model-path <model-name-or-path> [additional-arguments]\n\n"
"This command can launch either a standard language model server or a diffusion model server.\n"
"The server type is determined by the model path.\n"
"For specific arguments, please provide a model_path."
)
else:
raise Exception(
"Error: --model-path is required. "
"Please provide the path to the model."
)
return model_path
def main():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="subcommand", required=True)
serve_parser = subparsers.add_parser(
"serve",
help="Launch the SGLang server.",
add_help=False, # Defer help to the specific parser
)
serve_parser.set_defaults(func=serve)
generate_parser = subparsers.add_parser(
"generate",
help="Run inference on a multimodal model.",
add_help=False, # Defer help to the specific parser
)
generate_parser.set_defaults(func=generate)
args, extra_argv = parser.parse_known_args()
args.func(args, extra_argv)
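# Dispatch examples (illustrative model ids):
#   sglang serve --model-path meta-llama/Llama-3.1-8B-Instruct      -> serve()
#   sglang generate --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers   -> generate()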
# SPDX-License-Identifier: Apache-2.0
import argparse
import logging
import os
from sglang.cli.main import get_is_diffusion_model, get_model_path
from sglang.srt.utils import kill_process_tree
logger = logging.getLogger(__name__)
def serve(args, extra_argv):
model_path = get_model_path(extra_argv)
try:
is_diffusion_model = get_is_diffusion_model(model_path)
if is_diffusion_model:
# Logic for Diffusion Models
from sglang.multimodal_gen.runtime.entrypoints.cli.serve import (
add_multimodal_gen_serve_args,
execute_serve_cmd,
)
parser = argparse.ArgumentParser(
description="SGLang Diffusion Model Serving"
)
add_multimodal_gen_serve_args(parser)
parsed_args, remaining_argv = parser.parse_known_args(extra_argv)
execute_serve_cmd(parsed_args, remaining_argv)
else:
# Logic for Standard Language Models
from sglang.launch_server import run_server
from sglang.srt.server_args import prepare_server_args
            # prepare_server_args expects the argument list without the
            # program name, i.e. the equivalent of sys.argv[1:]
server_args = prepare_server_args(extra_argv)
run_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
@@ -7,19 +7,23 @@ import sys
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
try:
def run_server(server_args):
"""Run the server based on server_args.grpc_mode."""
if server_args.grpc_mode:
# Handle gRPC server
from sglang.srt.entrypoints.grpc_server import serve_grpc
asyncio.run(serve_grpc(server_args))
else:
# Handle HTTP server
from sglang.srt.entrypoints.http_server import launch_server
launch_server(server_args)
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
try:
run_server(server_args)
finally:
kill_process_tree(os.getpid(), include_parent=False)
<div align="center">
  <img src="assets/logos/logo.svg" width="30%"/>
</div>
**sgl-diffusion is an inference framework for accelerated image/video generation.**
sgl-diffusion features an end-to-end unified pipeline for accelerating diffusion models. It is designed to be modular and extensible, allowing users to easily add new optimizations and techniques.
## Key Features
sgl-diffusion has the following features:
- State-of-the-art performance optimizations for inference
- [Video Sparse Attention](https://arxiv.org/pdf/2505.13389)
- [Sliding Tile Attention](https://arxiv.org/pdf/2502.04507)
- [TeaCache](https://arxiv.org/pdf/2411.19108)
- [Sage Attention](https://arxiv.org/abs/2410.02367)
- USP (Unified Sequence Parallelism)
- CFG Parallel
- Diverse hardware and OS support
- Supported hardware: H100, H200, A100, B200, 4090
- Supported OS: Linux, Windows, macOS
## Getting Started
```bash
uv pip install "sglang[diffusion]" --prerelease=allow
```
For more information, check the [docs](https://github.com/sgl-project/sglang/tree/main/python/sglang/multimodal_gen/docs/install.md).
## Inference
Here's a minimal example to generate a video using the default settings:
```python
from sglang.multimodal_gen import DiffGenerator
def main():
# Create a diff generator from a pre-trained model
generator = DiffGenerator.from_pretrained(
model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
num_gpus=1, # Adjust based on your hardware
)
# Provide a prompt for your video
prompt = "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest."
# Generate the video
video = generator.generate(
prompt,
return_frames=True, # Also return frames from this call (defaults to False)
output_path="my_videos/", # Controls where videos are saved
save_output=True
)
if __name__ == '__main__':
main()
```
Or, more simply, with the CLI:
```bash
sglang generate --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--text-encoder-cpu-offload --pin-cpu-memory \
--prompt "A curious raccoon" \
--save-output
```
For more information, check the [docs](https://github.com/sgl-project/sglang/tree/main/python/sglang/multimodal_gen/docs/cli.md).
## Contributing
All contributions are welcome.
## Acknowledgement
We learned from and reused code from the following projects:
- [FastVideo](https://github.com/hao-ai-lab/FastVideo.git). The major components of this repo are based on a fork of FastVideo taken on Sept. 24, 2025.
- [xDiT](https://github.com/xdit-project/xDiT). We used its parallelism library.
- [diffusers](https://github.com/huggingface/diffusers). We used its pipeline design.
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.configs.pipelines import PipelineConfig
from sglang.multimodal_gen.configs.sample import SamplingParams
from sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator
__all__ = ["DiffGenerator", "PipelineConfig", "SamplingParams"]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Configs for pipelines, and pipeline modules (in models folder)
{
"temporal_chunk_size": 2,
"temporal_topk": 2,
"spatial_chunk_size": [4, 13],
"spatial_topk": 6,
"st_chunk_size": [4, 4, 13],
"st_topk": 18,
"moba_select_mode": "topk",
"moba_threshold": 0.25,
"moba_threshold_type": "query_head",
"first_full_layer": 0,
"first_full_step": 12,
"temporal_layer": 1,
"spatial_layer": 1,
"st_layer": 1
}
{
"temporal_chunk_size": 2,
"temporal_topk": 3,
"spatial_chunk_size": [3, 4],
"spatial_topk": 20,
"st_chunk_size": [4, 6, 4],
"st_topk": 15,
"moba_select_mode": "threshold",
"moba_threshold": 0.25,
"moba_threshold_type": "query_head",
"first_full_layer": 0,
"first_full_step": 12,
"temporal_layer": 1,
"spatial_layer": 1,
"st_layer": 1
}
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import dataclasses
from enum import Enum
from typing import Any, Optional
from sglang.multimodal_gen.configs.utils import update_config_from_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import FlexibleArgumentParser, StoreBoolean
logger = init_logger(__name__)
class DatasetType(str, Enum):
"""
Enumeration for different dataset types.
"""
HF = "hf"
MERGED = "merged"
@classmethod
def from_string(cls, value: str) -> "DatasetType":
"""Convert string to DatasetType enum."""
try:
return cls(value.lower())
except ValueError:
raise ValueError(
f"Invalid dataset type: {value}. Must be one of: {', '.join([m.value for m in cls])}"
) from None
@classmethod
def choices(cls) -> list[str]:
"""Get all available choices as strings for argparse."""
return [dataset_type.value for dataset_type in cls]
class VideoLoaderType(str, Enum):
"""
Enumeration for different video loaders.
"""
TORCHCODEC = "torchcodec"
TORCHVISION = "torchvision"
@classmethod
def from_string(cls, value: str) -> "VideoLoaderType":
"""Convert string to VideoLoader enum."""
try:
return cls(value.lower())
except ValueError:
raise ValueError(
f"Invalid video loader: {value}. Must be one of: {', '.join([m.value for m in cls])}"
) from None
@classmethod
def choices(cls) -> list[str]:
"""Get all available choices as strings for argparse."""
return [video_loader.value for video_loader in cls]
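# Illustration: both enums normalize case and reject unknown values, e.g.
#   DatasetType.from_string("HF")          -> DatasetType.HF
#   VideoLoaderType.from_string("ffmpeg")  -> ValueError listing "torchcodec, torchvision"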
@dataclasses.dataclass
class PreprocessConfig:
"""Configuration for preprocessing operations."""
# Model and dataset configuration
model_path: str = ""
dataset_path: str = ""
dataset_type: DatasetType = DatasetType.HF
dataset_output_dir: str = "./output"
# Dataloader configuration
dataloader_num_workers: int = 1
preprocess_video_batch_size: int = 2
# Saver configuration
samples_per_file: int = 64
flush_frequency: int = 256
# Video processing parameters
video_loader_type: VideoLoaderType = VideoLoaderType.TORCHCODEC
max_height: int = 480
max_width: int = 848
num_frames: int = 163
video_length_tolerance_range: float = 2.0
train_fps: int = 30
speed_factor: float = 1.0
drop_short_ratio: float = 1.0
do_temporal_sample: bool = False
# Model configuration
training_cfg_rate: float = 0.0
# framework configuration
seed: int = 42
@staticmethod
def add_cli_args(
parser: FlexibleArgumentParser, prefix: str = "preprocess"
) -> FlexibleArgumentParser:
"""Add preprocessing configuration arguments to the parser."""
prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else ""
preprocess_args = parser.add_argument_group("Preprocessing Arguments")
# Model & Dataset
preprocess_args.add_argument(
f"--{prefix_with_dot}model-path",
type=str,
default=PreprocessConfig.model_path,
help="Path to the model for preprocessing",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}dataset-path",
type=str,
default=PreprocessConfig.dataset_path,
help="Path to the dataset directory for preprocessing",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}dataset-type",
type=str,
choices=DatasetType.choices(),
default=PreprocessConfig.dataset_type.value,
help="Type of the dataset",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}dataset-output-dir",
type=str,
default=PreprocessConfig.dataset_output_dir,
help="The output directory where the dataset will be written.",
)
# Dataloader
preprocess_args.add_argument(
f"--{prefix_with_dot}dataloader-num-workers",
type=int,
default=PreprocessConfig.dataloader_num_workers,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}preprocess-video-batch-size",
type=int,
default=PreprocessConfig.preprocess_video_batch_size,
help="Batch size (per device) for the training dataloader.",
)
# Saver
preprocess_args.add_argument(
f"--{prefix_with_dot}samples-per-file",
type=int,
default=PreprocessConfig.samples_per_file,
help="Number of samples per output file",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}flush-frequency",
type=int,
default=PreprocessConfig.flush_frequency,
help="How often to save to parquet files",
)
# Video processing parameters
preprocess_args.add_argument(
f"--{prefix_with_dot}video-loader-type",
type=str,
choices=VideoLoaderType.choices(),
default=PreprocessConfig.video_loader_type.value,
help="Type of the video loader",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}max-height",
type=int,
default=PreprocessConfig.max_height,
help="Maximum height for video processing",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}max-width",
type=int,
default=PreprocessConfig.max_width,
help="Maximum width for video processing",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}num-frames",
type=int,
default=PreprocessConfig.num_frames,
help="Number of frames to process",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}video-length-tolerance-range",
type=float,
default=PreprocessConfig.video_length_tolerance_range,
help="Video length tolerance range",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}train-fps",
type=int,
default=PreprocessConfig.train_fps,
help="Training FPS",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}speed-factor",
type=float,
default=PreprocessConfig.speed_factor,
help="Speed factor for video processing",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}drop-short-ratio",
type=float,
default=PreprocessConfig.drop_short_ratio,
help="Ratio for dropping short videos",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}do-temporal-sample",
action=StoreBoolean,
default=PreprocessConfig.do_temporal_sample,
help="Whether to do temporal sampling",
)
# Model Training configuration
preprocess_args.add_argument(
f"--{prefix_with_dot}training-cfg-rate",
type=float,
default=PreprocessConfig.training_cfg_rate,
help="Training CFG rate",
)
preprocess_args.add_argument(
f"--{prefix_with_dot}seed",
type=int,
default=PreprocessConfig.seed,
help="Seed for random number generator",
)
return parser
@classmethod
def from_kwargs(cls, kwargs: dict[str, Any]) -> Optional["PreprocessConfig"]:
"""Create PreprocessConfig from keyword arguments."""
if "dataset_type" in kwargs and isinstance(kwargs["dataset_type"], str):
kwargs["dataset_type"] = DatasetType.from_string(kwargs["dataset_type"])
if "video_loader_type" in kwargs and isinstance(
kwargs["video_loader_type"], str
):
kwargs["video_loader_type"] = VideoLoaderType.from_string(
kwargs["video_loader_type"]
)
preprocess_config = cls()
if not update_config_from_args(
preprocess_config, kwargs, prefix="preprocess", pop_args=True
):
return None
return preprocess_config
def check_preprocess_config(self) -> None:
if self.dataset_path == "":
raise ValueError("dataset_path must be set for preprocess mode")
if self.samples_per_file <= 0:
raise ValueError("samples_per_file must be greater than 0")
if self.flush_frequency <= 0:
raise ValueError("flush_frequency must be greater than 0")
{
"embedded_cfg_scale": 6,
"flow_shift": 17,
"dit_cpu_offload": false,
"disable_autocast": false,
"precision": "bf16",
"vae_precision": "fp32",
"vae_tiling": true,
"vae_sp": true,
"vae_config": {
"load_encoder": false,
"load_decoder": true,
"tile_sample_min_height": 256,
"tile_sample_min_width": 256,
"tile_sample_min_num_frames": 16,
"tile_sample_stride_height": 192,
"tile_sample_stride_width": 192,
"tile_sample_stride_num_frames": 12,
"blend_num_frames": 4,
"use_tiling": true,
"use_temporal_tiling": true,
"use_parallel_tiling": true
},
"dit_config": {
"prefix": "Hunyuan",
"quant_config": null
},
"text_encoder_precisions": [
"fp16",
"fp16"
],
"text_encoder_configs": [
{
"prefix": "llama",
"quant_config": null,
"lora_config": null
},
{
"prefix": "clip",
"quant_config": null,
"lora_config": null,
"num_hidden_layers_override": null,
"require_post_norm": null
}
],
"mask_strategy_file_path": null,
"enable_torch_compile": false
}
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.configs.models.base import ModelConfig
from sglang.multimodal_gen.configs.models.dits.base import DiTConfig
from sglang.multimodal_gen.configs.models.encoders.base import EncoderConfig
from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig
__all__ = ["ModelConfig", "VAEConfig", "DiTConfig", "EncoderConfig"]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field, fields
from typing import Any, Dict
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
# 1. ArchConfig contains all fields from the diffusers/transformers config.json (i.e. all fields describing the model architecture)
# 2. ArchConfig should be inherited & overridden by each model's arch_config
# 3. Any field in ArchConfig is fixed upon initialization and should be hidden from users
@dataclass
class ArchConfig:
stacked_params_mapping: list[tuple[str, str, str]] = field(
default_factory=list
) # mapping from huggingface weight names to custom names
extra_attrs: Dict[str, Any] = field(default_factory=dict)
def __getattr__(self, name: str):
d = object.__getattribute__(self, "__dict__")
extras = d.get("extra_attrs")
if extras is not None and name in extras:
return extras[name]
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def __setattr__(self, key, value):
if key in type(self).__dataclass_fields__:
object.__setattr__(self, key, value)
else:
d = object.__getattribute__(self, "__dict__")
extras = d.get("extra_attrs")
if extras is None:
extras = {}
d["extra_attrs"] = extras
extras[key] = value
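# Illustration of the extra_attrs fallback defined above:
#   cfg = ArchConfig()
#   cfg.rope_theta = 10000          # not a dataclass field -> stored in extra_attrs
#   assert cfg.rope_theta == 10000  # resolved via __getattr__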
@dataclass
class ModelConfig:
# Every model config parameter can be categorized into either ArchConfig or everything else
# Diffuser/Transformer parameters
arch_config: ArchConfig = field(default_factory=ArchConfig)
# sgl-diffusion-specific parameters here
# i.e. STA, quantization, teacache
def __getattr__(self, name):
# Only called if 'name' is not found in ModelConfig directly
if hasattr(self.arch_config, name):
return getattr(self.arch_config, name)
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
    def __getstate__(self):
        # Return a copy of the instance's attribute dict for pickling
        state = self.__dict__.copy()
        return state
def __setstate__(self, state):
# Restore instance attributes from the unpickled state
self.__dict__.update(state)
# This should be used only when loading from transformers/diffusers
def update_model_arch(self, source_model_dict: dict[str, Any]) -> None:
"""
Update arch_config with source_model_dict
"""
        arch_config = self.arch_config
        # Unknown keys fall through to ArchConfig.__setattr__, which stores
        # them in extra_attrs instead of raising.
        for key, value in source_model_dict.items():
            setattr(arch_config, key, value)
if hasattr(arch_config, "__post_init__"):
arch_config.__post_init__()
def update_model_config(self, source_model_dict: dict[str, Any]) -> None:
assert (
"arch_config" not in source_model_dict
), "Source model config shouldn't contain arch_config."
valid_fields = {f.name for f in fields(self)}
for key, value in source_model_dict.items():
if key in valid_fields:
setattr(self, key, value)
else:
logger.warning(
"%s does not contain field '%s'!", type(self).__name__, key
)
raise AttributeError(f"Invalid field: {key}")
if hasattr(self, "__post_init__"):
self.__post_init__()
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from sglang.multimodal_gen.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
from sglang.multimodal_gen.configs.models.dits.stepvideo import StepVideoConfig
from sglang.multimodal_gen.configs.models.dits.wanvideo import WanVideoConfig
__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Any
from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig
from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
@dataclass
class DiTArchConfig(ArchConfig):
_fsdp_shard_conditions: list = field(default_factory=list)
_compile_conditions: list = field(default_factory=list)
param_names_mapping: dict = field(default_factory=dict)
reverse_param_names_mapping: dict = field(default_factory=dict)
lora_param_names_mapping: dict = field(default_factory=dict)
_supported_attention_backends: set[AttentionBackendEnum] = field(
default_factory=lambda: {
AttentionBackendEnum.SLIDING_TILE_ATTN,
AttentionBackendEnum.SAGE_ATTN,
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.VIDEO_SPARSE_ATTN,
AttentionBackendEnum.VMOBA_ATTN,
AttentionBackendEnum.SAGE_ATTN_THREE,
}
)
hidden_size: int = 0
num_attention_heads: int = 0
num_channels_latents: int = 0
exclude_lora_layers: list[str] = field(default_factory=list)
boundary_ratio: float | None = None
def __post_init__(self) -> None:
if not self._compile_conditions:
self._compile_conditions = self._fsdp_shard_conditions.copy()
@dataclass
class DiTConfig(ModelConfig):
arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)
    # sgl-diffusion DiT-specific parameters
prefix: str = ""
quant_config: QuantizationConfig | None = None
@staticmethod
def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any:
"""Add CLI arguments for DiTConfig fields"""
parser.add_argument(
f"--{prefix}.prefix",
type=str,
dest=f"{prefix.replace('-', '_')}.prefix",
default=DiTConfig.prefix,
help="Prefix for the DiT model",
)
parser.add_argument(
f"--{prefix}.quant-config",
type=str,
dest=f"{prefix.replace('-', '_')}.quant_config",
default=None,
help="Quantization configuration for the DiT model",
)
return parser
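# Illustration: the dotted flags above parse into dotted dests, e.g.
#   args = parser.parse_args(["--dit-config.prefix", "Wan"])
#   getattr(args, "dit_config.prefix")  # -> "Wan"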
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Tuple
from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig
@dataclass
class FluxArchConfig(DiTArchConfig):
patch_size: int = 1
in_channels: int = 64
out_channels: int | None = None
num_layers: int = 19
num_single_layers: int = 38
attention_head_dim: int = 128
num_attention_heads: int = 24
joint_attention_dim: int = 4096
pooled_projection_dim: int = 768
guidance_embeds: bool = False
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)
def __post_init__(self):
super().__post_init__()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.out_channels
@dataclass
class FluxConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=FluxArchConfig)
prefix: str = "Flux"