Unverified commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import torch # type: ignore
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.denoising import DenoisingStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
try:
from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import (
SlidingTileAttentionBackend,
)
st_attn_available = True
except ImportError:
st_attn_available = False
SlidingTileAttentionBackend = None # type: ignore
try:
from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import (
VideoSparseAttentionBackend,
)
vsa_available = True
except ImportError:
vsa_available = False
VideoSparseAttentionBackend = None # type: ignore
logger = init_logger(__name__)
class CausalDMDDenoisingStage(DenoisingStage):
"""
Denoising stage for causal diffusion.
"""
def __init__(self, transformer, scheduler) -> None:
super().__init__(transformer, scheduler)
# KV and cross-attention cache state (initialized on first forward)
self.kv_cache1: list | None = None
self.crossattn_cache: list | None = None
# Model-dependent constants (aligned with causal_inference.py assumptions)
self.num_transformer_blocks = self.transformer.config.arch_config.num_layers
self.num_frames_per_block = (
self.transformer.config.arch_config.num_frames_per_block
)
self.sliding_window_num_frames = (
self.transformer.config.arch_config.sliding_window_num_frames
)
try:
self.local_attn_size = getattr(
self.transformer.model, "local_attn_size", -1
) # type: ignore
except Exception:
self.local_attn_size = -1
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
target_dtype = torch.bfloat16
autocast_enabled = (
target_dtype != torch.float32
) and not server_args.disable_autocast
latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2]
patch_ratio = (
self.transformer.config.arch_config.patch_size[-1]
* self.transformer.config.arch_config.patch_size[-2]
)
self.frame_seq_length = latent_seq_length // patch_ratio
# TODO(will): make this a parameter once we add i2v support
independent_first_frame = self.transformer.independent_first_frame
# Timesteps for DMD
timesteps = torch.tensor(
server_args.pipeline_config.dmd_denoising_steps, dtype=torch.long
).cpu()
if server_args.pipeline_config.warp_denoising_step:
logger.info("Warping timesteps...")
scheduler_timesteps = torch.cat(
(self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32))
)
timesteps = scheduler_timesteps[1000 - timesteps]
timesteps = timesteps.to(get_local_torch_device())
logger.info("Using timesteps: %s", timesteps)
# Image kwargs (kept empty unless caller provides compatible args)
image_kwargs: dict = {}
pos_cond_kwargs = self.prepare_extra_func_kwargs(
self.transformer.forward,
{
# "encoder_hidden_states_2": batch.clip_embedding_pos,
"encoder_attention_mask": batch.prompt_attention_mask,
},
)
# STA
if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
self.prepare_sta_param(batch, server_args)
# Latents and prompts
assert batch.latents is not None, "latents must be provided"
latents = batch.latents # [B, C, T, H, W]
b, c, t, h, w = latents.shape
prompt_embeds = batch.prompt_embeds
assert torch.isnan(prompt_embeds[0]).sum() == 0
# Initialize or reset caches
if self.kv_cache1 is None:
self._initialize_kv_cache(
batch_size=latents.shape[0], dtype=target_dtype, device=latents.device
)
self._initialize_crossattn_cache(
batch_size=latents.shape[0],
max_text_len=server_args.pipeline_config.text_encoder_configs[
0
].arch_config.text_len,
dtype=target_dtype,
device=latents.device,
)
else:
assert self.crossattn_cache is not None
# reset cross-attention cache
for block_index in range(self.num_transformer_blocks):
self.crossattn_cache[block_index]["is_init"] = False # type: ignore
# reset kv cache pointers
for block_index in range(len(self.kv_cache1)):
self.kv_cache1[block_index]["global_end_index"] = (
torch.tensor( # type: ignore
[0], dtype=torch.long, device=latents.device
)
)
self.kv_cache1[block_index]["local_end_index"] = (
torch.tensor( # type: ignore
[0], dtype=torch.long, device=latents.device
)
)
# Optional: cache context features from provided image latents prior to generation
current_start_frame = 0
if getattr(batch, "image_latent", None) is not None:
image_latent = batch.image_latent
assert image_latent is not None
input_frames = image_latent.shape[2]
# timestep zero (or configured context noise) for cache warm-up
t_zero = torch.zeros(
[latents.shape[0]], device=latents.device, dtype=torch.long
)
if independent_first_frame and input_frames >= 1:
# warm-up with the very first frame independently
image_first_btchw = (
image_latent[:, :, :1, :, :].to(target_dtype).permute(0, 2, 1, 3, 4)
)
with torch.autocast(
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
):
_ = self.transformer(
image_first_btchw,
prompt_embeds,
t_zero,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
**image_kwargs,
**pos_cond_kwargs,
)
current_start_frame += 1
remaining_frames = input_frames - 1
else:
remaining_frames = input_frames
# process remaining input frames in blocks of num_frames_per_block
while remaining_frames > 0:
block = min(self.num_frames_per_block, remaining_frames)
ref_btchw = (
image_latent[
:, :, current_start_frame : current_start_frame + block, :, :
]
.to(target_dtype)
.permute(0, 2, 1, 3, 4)
)
with torch.autocast(
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
):
_ = self.transformer(
ref_btchw,
prompt_embeds,
t_zero,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
**image_kwargs,
**pos_cond_kwargs,
)
current_start_frame += block
remaining_frames -= block
# Base position offset from any cache warm-up
pos_start_base = current_start_frame
# Determine block sizes
if not independent_first_frame or (
independent_first_frame and batch.image_latent is not None
):
if t % self.num_frames_per_block != 0:
raise ValueError(
"num_frames must be divisible by num_frames_per_block for causal DMD denoising"
)
num_blocks = t // self.num_frames_per_block
block_sizes = [self.num_frames_per_block] * num_blocks
start_index = 0
else:
if (t - 1) % self.num_frames_per_block != 0:
raise ValueError(
"(num_frames - 1) must be divisible by num_frame_per_block when independent_first_frame=True"
)
num_blocks = (t - 1) // self.num_frames_per_block
block_sizes = [1] + [self.num_frames_per_block] * num_blocks
start_index = 0
# DMD loop in causal blocks
with self.progress_bar(total=len(block_sizes) * len(timesteps)) as progress_bar:
for current_num_frames in block_sizes:
current_latents = latents[
:, :, start_index : start_index + current_num_frames, :, :
]
# use BTCHW for DMD conversion routines
noise_latents_btchw = current_latents.permute(0, 2, 1, 3, 4)
video_raw_latent_shape = noise_latents_btchw.shape
for i, t_cur in enumerate(timesteps):
# Copy for pred conversion
noise_latents = noise_latents_btchw.clone()
latent_model_input = current_latents.to(target_dtype)
if (
batch.image_latent is not None
and independent_first_frame
and start_index == 0
):
latent_model_input = torch.cat(
[latent_model_input, batch.image_latent.to(target_dtype)],
dim=2,
)
# Prepare inputs
t_expand = t_cur.repeat(latent_model_input.shape[0])
# Attention metadata if needed
if (
vsa_available
and self.attn_backend == VideoSparseAttentionBackend
):
self.attn_metadata_builder_cls = (
self.attn_backend.get_builder_cls()
)
if self.attn_metadata_builder_cls is not None:
self.attn_metadata_builder = (
self.attn_metadata_builder_cls()
)
attn_metadata = self.attn_metadata_builder.build( # type: ignore
current_timestep=i, # type: ignore
raw_latent_shape=(
current_num_frames,
h,
w,
), # type: ignore
patch_size=server_args.pipeline_config.dit_config.patch_size, # type: ignore
STA_param=batch.STA_param, # type: ignore
VSA_sparsity=server_args.VSA_sparsity, # type: ignore
device=get_local_torch_device(), # type: ignore
) # type: ignore
assert (
attn_metadata is not None
), "attn_metadata cannot be None"
else:
attn_metadata = None
else:
attn_metadata = None
with (
torch.autocast(
device_type="cuda",
dtype=target_dtype,
enabled=autocast_enabled,
),
set_forward_context(
current_timestep=i,
attn_metadata=attn_metadata,
forward_batch=batch,
),
):
# Run transformer; follow DMD stage pattern
t_expanded_noise = t_cur * torch.ones(
(latent_model_input.shape[0], 1),
device=latent_model_input.device,
dtype=torch.long,
)
pred_noise_btchw = self.transformer(
latent_model_input,
prompt_embeds,
t_expanded_noise,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=(pos_start_base + start_index)
* self.frame_seq_length,
start_frame=start_index,
**image_kwargs,
**pos_cond_kwargs,
).permute(0, 2, 1, 3, 4)
# Convert pred noise to pred video with FM Euler scheduler utilities
pred_video_btchw = pred_noise_to_pred_video(
pred_noise=pred_noise_btchw.flatten(0, 1),
noise_input_latent=noise_latents.flatten(0, 1),
timestep=t_expand,
scheduler=self.scheduler,
).unflatten(0, pred_noise_btchw.shape[:2])
if i < len(timesteps) - 1:
next_timestep = timesteps[i + 1] * torch.ones(
[1], dtype=torch.long, device=pred_video_btchw.device
)
noise = torch.randn(
video_raw_latent_shape,
dtype=pred_video_btchw.dtype,
generator=(
batch.generator[0]
if isinstance(batch.generator, list)
else batch.generator
),
).to(self.device)
noise_btchw = noise
noise_latents_btchw = self.scheduler.add_noise(
pred_video_btchw.flatten(0, 1),
noise_btchw.flatten(0, 1),
next_timestep,
).unflatten(0, pred_video_btchw.shape[:2])
current_latents = noise_latents_btchw.permute(0, 2, 1, 3, 4)
else:
current_latents = pred_video_btchw.permute(0, 2, 1, 3, 4)
if progress_bar is not None:
progress_bar.update()
# Write back and advance
latents[:, :, start_index : start_index + current_num_frames, :, :] = (
current_latents
)
# Re-run with context timestep to update KV cache using clean context
context_noise = getattr(server_args.pipeline_config, "context_noise", 0)
t_context = torch.ones(
[latents.shape[0]], device=latents.device, dtype=torch.long
) * int(context_noise)
context_bcthw = current_latents.to(target_dtype)
with (
torch.autocast(
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
),
set_forward_context(
current_timestep=0,
attn_metadata=attn_metadata,
forward_batch=batch,
),
):
t_expanded_context = t_context.unsqueeze(1)
_ = self.transformer(
context_bcthw,
prompt_embeds,
t_expanded_context,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=(pos_start_base + start_index)
* self.frame_seq_length,
start_frame=start_index,
**image_kwargs,
**pos_cond_kwargs,
)
start_index += current_num_frames
batch.latents = latents
return batch
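# Illustrative sketch of the block-causal DMD pattern implemented above (toy
# helpers only; `denoise_fn` and `renoise_fn` are hypothetical stand-ins for the
# transformer call plus pred_noise_to_pred_video and scheduler.add_noise):
#
#     def toy_causal_dmd(latents, timesteps, denoise_fn, renoise_fn, block=3):
#         b, c, t, h, w = latents.shape
#         for start in range(0, t, block):
#             cur = latents[:, :, start:start + block]
#             for i, step in enumerate(timesteps):
#                 clean = denoise_fn(cur, step)  # predict the clean block
#                 if i < len(timesteps) - 1:
#                     # re-noise the prediction to the next DMD timestep
#                     cur = renoise_fn(clean, timesteps[i + 1])
#                 else:
#                     cur = clean
#             # later blocks attend to this block through the rolling KV cache
#             latents[:, :, start:start + block] = cur
#         return latents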
def _initialize_kv_cache(self, batch_size, dtype, device) -> None:
"""
Initialize a Per-GPU KV cache aligned with the Wan model assumptions.
"""
kv_cache1 = []
num_attention_heads = self.transformer.num_attention_heads
attention_head_dim = self.transformer.attention_head_dim
if self.local_attn_size != -1:
kv_cache_size = self.local_attn_size * self.frame_seq_length
else:
kv_cache_size = self.frame_seq_length * self.sliding_window_num_frames
for _ in range(self.num_transformer_blocks):
kv_cache1.append(
{
"k": torch.zeros(
[
batch_size,
kv_cache_size,
num_attention_heads,
attention_head_dim,
],
dtype=dtype,
device=device,
),
"v": torch.zeros(
[
batch_size,
kv_cache_size,
num_attention_heads,
attention_head_dim,
],
dtype=dtype,
device=device,
),
"global_end_index": torch.tensor(
[0], dtype=torch.long, device=device
),
"local_end_index": torch.tensor(
[0], dtype=torch.long, device=device
),
}
)
self.kv_cache1 = kv_cache1
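# Sizing example (hypothetical numbers, not taken from a real config): with a
# 60x104 latent grid and a 2x2 spatial patch size, frame_seq_length is
# (60 // 2) * (104 // 2) = 1560 tokens per frame, so with
# sliding_window_num_frames = 21 each block's rolling cache holds
# 21 * 1560 = 32760 key/value positions per batch element and head.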
def _initialize_crossattn_cache(
self, batch_size, max_text_len, dtype, device
) -> None:
"""
Initialize a Per-GPU cross-attention cache aligned with the Wan model assumptions.
"""
crossattn_cache = []
num_attention_heads = self.transformer.num_attention_heads
attention_head_dim = self.transformer.attention_head_dim
for _ in range(self.num_transformer_blocks):
crossattn_cache.append(
{
"k": torch.zeros(
[
batch_size,
max_text_len,
num_attention_heads,
attention_head_dim,
],
dtype=dtype,
device=device,
),
"v": torch.zeros(
[
batch_size,
max_text_len,
num_attention_heads,
attention_head_dim,
],
dtype=dtype,
device=device,
),
"is_init": False,
}
)
self.crossattn_cache = crossattn_cache
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify denoising stage inputs."""
result = VerificationResult()
result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
result.add_check("image_embeds", batch.image_embeds, V.is_list)
result.add_check(
"image_latent", batch.image_latent, V.none_or_tensor_with_dims(5)
)
result.add_check(
"num_inference_steps", batch.num_inference_steps, V.positive_int
)
result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
result.add_check("eta", batch.eta, V.non_negative_float)
result.add_check("generator", batch.generator, V.generator_or_list_generators)
result.add_check(
"do_classifier_free_guidance",
batch.do_classifier_free_guidance,
V.bool_value,
)
result.add_check(
"negative_prompt_embeds",
batch.negative_prompt_embeds,
lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x),
)
return result
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Conditioning stage for diffusion pipelines.
"""
import torch
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class ConditioningStage(PipelineStage):
"""
Stage for applying conditioning to the diffusion process.
This stage handles the application of conditioning, such as classifier-free guidance,
to the diffusion process.
"""
@torch.no_grad()
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Apply conditioning to the diffusion process.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with applied conditioning.
"""
# TODO: apply classifier-free guidance conditioning here; both branches
# currently return the batch unchanged, so the concatenation logic below
# is unreachable.
if not batch.do_classifier_free_guidance:
return batch
else:
return batch
logger.info("batch.negative_prompt_embeds: %s", batch.negative_prompt_embeds)
logger.info(
"do_classifier_free_guidance: %s", batch.do_classifier_free_guidance
)
logger.info("cfg_scale: %s", batch.guidance_scale)
# Ensure negative prompt embeddings are available
assert (
batch.negative_prompt_embeds is not None
), "Negative prompt embeddings are required for classifier-free guidance"
# Concatenate primary embeddings and masks
batch.prompt_embeds = torch.cat(
[batch.negative_prompt_embeds, batch.prompt_embeds]
)
if batch.attention_mask is not None:
batch.attention_mask = torch.cat(
[batch.negative_attention_mask, batch.attention_mask]
)
# Concatenate secondary embeddings and masks if present
if batch.prompt_embeds_2 is not None:
batch.prompt_embeds_2 = torch.cat(
[batch.negative_prompt_embeds_2, batch.prompt_embeds_2]
)
if batch.attention_mask_2 is not None:
batch.attention_mask_2 = torch.cat(
[batch.negative_attention_mask_2, batch.attention_mask_2]
)
return batch
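# Illustrative shapes (hypothetical sizes): with prompt_embeds and
# negative_prompt_embeds both of shape [B, L, D], the concatenation above yields
# a [2B, L, D] tensor, e.g.
#
#     embeds = torch.cat([neg, pos])        # [2B, L, D]
#     uncond, cond = noise_pred.chunk(2)    # split the batched prediction
#     guided = uncond + guidance_scale * (cond - uncond)
#
# so a single batched forward pass provides both halves needed for
# classifier-free guidance.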
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify conditioning stage inputs."""
result = VerificationResult()
result.add_check(
"do_classifier_free_guidance",
batch.do_classifier_free_guidance,
V.bool_value,
)
result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
result.add_check(
"negative_prompt_embeds",
batch.negative_prompt_embeds,
lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x),
)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify conditioning stage outputs."""
result = VerificationResult()
result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
return result
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Decoding stage for diffusion pipelines.
"""
import weakref
import torch
from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig
from sglang.multimodal_gen.configs.pipelines.qwen_image import (
QwenImageEditPipelineConfig,
QwenImagePipelineConfig,
)
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.loader.component_loader import VAELoader
from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import (
PipelineStage,
StageParallelismType,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import PRECISION_TO_TYPE
logger = init_logger(__name__)
class DecodingStage(PipelineStage):
"""
Stage for decoding latent representations into pixel space.
This stage handles the decoding of latent representations into the final
output format (e.g., pixel values).
"""
def __init__(self, vae, pipeline=None) -> None:
self.vae: ParallelTiledVAE = vae
self.pipeline = weakref.ref(pipeline) if pipeline else None
@property
def parallelism_type(self) -> StageParallelismType:
if get_global_server_args().enable_cfg_parallel:
return StageParallelismType.MAIN_RANK_ONLY
return StageParallelismType.REPLICATED
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify decoding stage inputs."""
result = VerificationResult()
# Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents]
# result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify decoding stage outputs."""
result = VerificationResult()
# Decoded video/images: [batch_size, channels, frames, height, width]
# result.add_check("output", batch.output, [V.is_tensor, V.with_dims(5)])
return result
def scale_and_shift(
self, vae_arch_config: VAEArchConfig, latents: torch.Tensor, server_args
):
# 1. scale
is_qwen_image = isinstance(
server_args.pipeline_config, QwenImagePipelineConfig
) or isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig)
if is_qwen_image:
scaling_factor = 1.0 / torch.tensor(
vae_arch_config.latents_std, device=latents.device
).view(1, vae_arch_config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
else:
scaling_factor = vae_arch_config.scaling_factor
if isinstance(scaling_factor, torch.Tensor):
latents = latents / scaling_factor.to(latents.device, latents.dtype)
else:
latents = latents / scaling_factor
# 2. shift
if is_qwen_image:
shift_factor = (
torch.tensor(vae_arch_config.latents_mean)
.view(1, vae_arch_config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
else:
shift_factor = getattr(vae_arch_config, "shift_factor", None)
# Apply shifting if needed
if shift_factor is not None:
if isinstance(shift_factor, torch.Tensor):
latents += shift_factor.to(latents.device, latents.dtype)
else:
latents += shift_factor
return latents
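# Net effect of scale_and_shift, written as equations (broadcast over z_dim
# channels):
#   Qwen-Image VAEs:  latents = latents * latents_std + latents_mean
#                     (scaling_factor is 1 / latents_std, so dividing by it
#                      multiplies by the per-channel std)
#   other VAEs:       latents = latents / scaling_factor + shift_factor
#                     (shift applied only when a shift_factor is defined)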
@torch.no_grad()
def decode(self, latents: torch.Tensor, server_args: ServerArgs) -> torch.Tensor:
"""
Decode latent representations into pixel space using VAE.
Args:
latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents)
server_args: Configuration containing:
- disable_autocast: Whether to disable automatic mixed precision (default: False)
- pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16")
- pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency
Returns:
Decoded video tensor with shape (batch, channels, frames, height, width),
normalized to [0, 1] range and moved to CPU as float32
"""
self.vae = self.vae.to(get_local_torch_device())
latents = latents.to(get_local_torch_device())
# Setup VAE precision
vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
vae_autocast_enabled = (
vae_dtype != torch.float32
) and not server_args.disable_autocast
vae_arch_config = server_args.pipeline_config.vae_config.arch_config
# scale and shift
latents = self.scale_and_shift(vae_arch_config, latents, server_args)
# Decode latents
with torch.autocast(
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
):
try:
# TODO: make it more specific
if server_args.pipeline_config.vae_tiling:
self.vae.enable_tiling()
except Exception:
pass
if not vae_autocast_enabled:
latents = latents.to(vae_dtype)
image = self.vae.decode(latents)
# De-normalize image to [0, 1] range
image = (image / 2 + 0.5).clamp(0, 1)
return image
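# Hedged usage sketch (shapes are illustrative): given denoised latents of shape
# [B, C_latent, T, h, w], `decode` returns frames of shape [B, C, T, H, W] with
# values in [0, 1]; the VAE output in [-1, 1] is mapped via
#     image = (image / 2 + 0.5).clamp(0, 1)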
@torch.no_grad()
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> OutputBatch:
"""
Decode latent representations into pixel space.
This method processes the batch through the VAE decoder, converting latent
representations to pixel-space video/images. It also optionally decodes
trajectory latents for visualization purposes.
Args:
batch: The current batch containing:
- latents: Tensor to decode (batch, channels, frames, height_latents, width_latents)
- return_trajectory_decoded (optional): Flag to decode trajectory latents
- trajectory_latents (optional): Latents at different timesteps
- trajectory_timesteps (optional): Corresponding timesteps
server_args: Configuration containing:
- output_type: "latent" to skip decoding, otherwise decode to pixels
- vae_cpu_offload: Whether to offload VAE to CPU after decoding
- model_loaded: Track VAE loading state
- model_paths: Path to VAE model if loading needed
Returns:
Modified batch with:
- output: Decoded frames (batch, channels, frames, height, width) as CPU float32
- trajectory_decoded (if requested): List of decoded frames per timestep
"""
# load vae if not already loaded (used for memory constrained devices)
pipeline = self.pipeline() if self.pipeline else None
if not server_args.model_loaded["vae"]:
loader = VAELoader()
self.vae = loader.load(server_args.model_paths["vae"], server_args)
if pipeline:
pipeline.add_module("vae", self.vae)
server_args.model_loaded["vae"] = True
if server_args.output_type == "latent":
frames = batch.latents
else:
frames = self.decode(batch.latents, server_args)
# decode trajectory latents if needed
if batch.return_trajectory_decoded:
trajectory_decoded = []
assert (
batch.trajectory_latents is not None
), "batch should have trajectory latents"
for idx in range(batch.trajectory_latents.shape[1]):
# batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width]
cur_latent = batch.trajectory_latents[:, idx, :, :, :, :]
cur_timestep = batch.trajectory_timesteps[idx]
logger.info("decoding trajectory latent for timestep: %s", cur_timestep)
decoded_frames = self.decode(cur_latent, server_args)
trajectory_decoded.append(decoded_frames.cpu().float())
else:
trajectory_decoded = None
# Convert to CPU float32 for compatibility
frames = frames.cpu().float()
# Update batch with decoded image
output_batch = OutputBatch(
output=frames,
trajectory_timesteps=batch.trajectory_timesteps,
trajectory_latents=batch.trajectory_latents,
trajectory_decoded=trajectory_decoded,
)
# Offload models if needed
if hasattr(self, "maybe_free_model_hooks"):
self.maybe_free_model_hooks()
if server_args.vae_cpu_offload:
self.vae.to("cpu")
if torch.backends.mps.is_available():
del self.vae
if pipeline is not None and "vae" in pipeline.modules:
del pipeline.modules["vae"]
server_args.model_loaded["vae"] = False
return output_batch
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Denoising stage for diffusion pipelines.
"""
import inspect
import math
import os
import time
import weakref
from collections.abc import Iterable
from functools import lru_cache
from typing import Any
import torch
import torch.profiler
from einops import rearrange
from tqdm.auto import tqdm
from sglang.multimodal_gen.configs.pipelines.base import STA_Mode
from sglang.multimodal_gen.runtime.distributed import (
cfg_model_parallel_all_reduce,
get_local_torch_device,
get_sp_parallel_rank,
get_sp_world_size,
get_world_group,
)
from sglang.multimodal_gen.runtime.distributed.communication_op import (
sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_cfg_group,
get_classifier_free_guidance_rank,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (
FlashAttentionBackend,
)
from sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend
from sglang.multimodal_gen.runtime.layers.attention.STA_configuration import (
configure_sta,
save_mask_search_results,
)
from sglang.multimodal_gen.runtime.loader.component_loader import TransformerLoader
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import (
PipelineStage,
StageParallelismType,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.platforms.interface import AttentionBackendEnum
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import dict_to_3d_list, masks_like
try:
from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import (
SlidingTileAttentionBackend,
)
st_attn_available = True
except ImportError:
st_attn_available = False
try:
from sglang.multimodal_gen.runtime.layers.attention.backends.vmoba import (
VMOBAAttentionBackend,
)
from sglang.multimodal_gen.utils import is_vmoba_available
vmoba_attn_available = is_vmoba_available()
except ImportError:
vmoba_attn_available = False
try:
from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import (
VideoSparseAttentionBackend,
)
vsa_available = True
except ImportError:
vsa_available = False
logger = init_logger(__name__)
class DenoisingStage(PipelineStage):
"""
Stage for running the denoising loop in diffusion pipelines.
This stage handles the iterative denoising process that transforms
the initial noise into the final output.
"""
def __init__(
self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None
) -> None:
super().__init__()
self.transformer = transformer
self.transformer_2 = transformer_2
hidden_size = self.server_args.pipeline_config.dit_config.hidden_size
num_attention_heads = (
self.server_args.pipeline_config.dit_config.num_attention_heads
)
attn_head_size = hidden_size // num_attention_heads
# torch compile
if self.server_args.enable_torch_compile:
full_graph = False
self.transformer = torch.compile(
self.transformer, mode="max-autotune", fullgraph=full_graph
)
self.transformer_2 = (
torch.compile(
self.transformer_2, mode="max-autotune", fullgraph=full_graph
)
if transformer_2 is not None
else None
)
self.scheduler = scheduler
self.vae = vae
self.pipeline = weakref.ref(pipeline) if pipeline else None
self.attn_backend = get_attn_backend(
head_size=attn_head_size,
dtype=torch.float16, # TODO(will): hack
supported_attention_backends={
AttentionBackendEnum.SLIDING_TILE_ATTN,
AttentionBackendEnum.VIDEO_SPARSE_ATTN,
AttentionBackendEnum.VMOBA_ATTN,
AttentionBackendEnum.FA3,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.SAGE_ATTN_THREE,
}, # hack
)
# cfg
self.guidance = None
# misc
self.profiler = None
@lru_cache(maxsize=8)
def _build_guidance(self, batch_size, target_dtype, device, guidance_val):
"""Builds a guidance tensor. This method is cached."""
return (
torch.full(
(batch_size,),
guidance_val,
dtype=torch.float32,
device=device,
).to(target_dtype)
* 1000.0
)
def get_or_build_guidance(self, bsz: int, dtype, device):
"""
Get the guidance tensor, using a cached version if available.
This method retrieves a cached guidance tensor using `_build_guidance`.
The caching is based on batch size, dtype, device, and the guidance value,
preventing repeated tensor creation within the denoising loop.
"""
if self.server_args.pipeline_config.should_use_guidance:
# TODO: should the guidance_scale be picked-up from sampling_params?
guidance_val = self.server_args.pipeline_config.embedded_cfg_scale
return self._build_guidance(bsz, dtype, device, guidance_val)
else:
return None
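# Worked example (hypothetical value): with embedded_cfg_scale = 6.0 and a batch
# of 2, _build_guidance returns tensor([6000., 6000.]) in the target dtype,
# since the embedded guidance value is multiplied by 1000 before being passed to
# the transformer.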
@property
def parallelism_type(self) -> StageParallelismType:
# return StageParallelismType.CFG_PARALLEL if get_global_server_args().enable_cfg_parallel else StageParallelismType.REPLICATED
return StageParallelismType.REPLICATED
def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
"""
Prepare all necessary invariant variables for the denoising loop.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
A dictionary containing all the prepared variables for the denoising loop.
"""
pipeline = self.pipeline() if self.pipeline else None
if not server_args.model_loaded["transformer"]:
loader = TransformerLoader()
self.transformer = loader.load(
server_args.model_paths["transformer"], server_args
)
if self.server_args.enable_torch_compile:
self.transformer = torch.compile(
self.transformer, mode="max-autotune", fullgraph=True
)
if pipeline:
pipeline.add_module("transformer", self.transformer)
server_args.model_loaded["transformer"] = True
# Prepare extra step kwargs for scheduler
extra_step_kwargs = self.prepare_extra_func_kwargs(
self.scheduler.step,
{"generator": batch.generator, "eta": batch.eta},
)
# Setup precision and autocast settings
target_dtype = torch.bfloat16
autocast_enabled = (
target_dtype != torch.float32
) and not server_args.disable_autocast
# Handle sequence parallelism if enabled
self._preprocess_sp_latents(batch)
# Get timesteps and calculate warmup steps
timesteps = batch.timesteps
if timesteps is None:
raise ValueError("Timesteps must be provided")
num_inference_steps = batch.num_inference_steps
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# Prepare image latents and embeddings for I2V generation
image_embeds = batch.image_embeds
if len(image_embeds) > 0:
image_embeds = [
image_embed.to(target_dtype) for image_embed in image_embeds
]
# Prepare STA parameters
if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
self.prepare_sta_param(batch, server_args)
# Get latents and embeddings
latents = batch.latents
prompt_embeds = batch.prompt_embeds
# Removed Tensor truthiness assert to avoid GPU sync
neg_prompt_embeds = None
if batch.do_classifier_free_guidance:
neg_prompt_embeds = batch.negative_prompt_embeds
assert neg_prompt_embeds is not None
# Removed Tensor truthiness assert to avoid GPU sync
# (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio
if batch.boundary_ratio is not None:
logger.info(
"Overriding boundary ratio from %s to %s",
boundary_ratio,
batch.boundary_ratio,
)
boundary_ratio = batch.boundary_ratio
if boundary_ratio is not None:
boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
else:
boundary_timestep = None
# TI2V specific preparations
z, mask2, seq_len = None, None, None
# FIXME: should probably move to the latent preparation stage so offload is handled there
if server_args.pipeline_config.ti2v_task and batch.pil_image is not None:
# Wan2.2 TI2V directly replaces the first frame of the latent with
# the image latent instead of appending along the channel dim
assert batch.image_latent is None, "TI2V task should not have image latents"
assert self.vae is not None, "VAE is not provided for TI2V task"
self.vae = self.vae.to(batch.pil_image.device)
z = self.vae.encode(batch.pil_image).mean.float()
if self.vae.device != "cpu" and server_args.vae_cpu_offload:
self.vae = self.vae.to("cpu")
if hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None:
if isinstance(self.vae.shift_factor, torch.Tensor):
z -= self.vae.shift_factor.to(z.device, z.dtype)
else:
z -= self.vae.shift_factor
if isinstance(self.vae.scaling_factor, torch.Tensor):
z = z * self.vae.scaling_factor.to(z.device, z.dtype)
else:
z = z * self.vae.scaling_factor
latent_model_input = latents.to(target_dtype).squeeze(0)
_, mask2 = masks_like([latent_model_input], zero=True)
latents = (1.0 - mask2[0]) * z + mask2[0] * latent_model_input
latents = latents.to(get_local_torch_device())
F = batch.num_frames
temporal_scale = (
server_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
)
spatial_scale = (
server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
)
patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size
seq_len = (
((F - 1) // temporal_scale + 1)
* (batch.height // spatial_scale)
* (batch.width // spatial_scale)
// (patch_size[1] * patch_size[2])
)
seq_len = (
int(math.ceil(seq_len / get_sp_world_size())) * get_sp_world_size()
)
guidance = self.get_or_build_guidance(
# TODO: replace with raw_latent_shape?
latents.shape[0],
latents.dtype,
latents.device,
)
image_kwargs = self.prepare_extra_func_kwargs(
self.transformer.forward,
{
# TODO: make sure on-device
"encoder_hidden_states_image": image_embeds,
"mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24),
},
)
pos_cond_kwargs = self.prepare_extra_func_kwargs(
self.transformer.forward,
{
"encoder_hidden_states_2": batch.clip_embedding_pos,
"encoder_attention_mask": batch.prompt_attention_mask,
}
| server_args.pipeline_config.prepare_pos_cond_kwargs(
batch,
self.device,
getattr(self.transformer, "rotary_emb", None),
dtype=target_dtype,
),
)
if batch.do_classifier_free_guidance:
neg_cond_kwargs = self.prepare_extra_func_kwargs(
self.transformer.forward,
{
"encoder_hidden_states_2": batch.clip_embedding_neg,
"encoder_attention_mask": batch.negative_attention_mask,
}
| server_args.pipeline_config.prepare_neg_cond_kwargs(
batch,
self.device,
getattr(self.transformer, "rotary_emb", None),
dtype=target_dtype,
),
)
else:
neg_cond_kwargs = {}
return {
"extra_step_kwargs": extra_step_kwargs,
"target_dtype": target_dtype,
"autocast_enabled": autocast_enabled,
"timesteps": timesteps,
"num_inference_steps": num_inference_steps,
"num_warmup_steps": num_warmup_steps,
"image_kwargs": image_kwargs,
"pos_cond_kwargs": pos_cond_kwargs,
"neg_cond_kwargs": neg_cond_kwargs,
"latents": latents,
"prompt_embeds": prompt_embeds,
"neg_prompt_embeds": neg_prompt_embeds,
"boundary_timestep": boundary_timestep,
"z": z,
"mask2": mask2,
"seq_len": seq_len,
"guidance": guidance,
}
def _post_denoising_loop(
self,
batch: Req,
latents: torch.Tensor,
trajectory_latents: list,
trajectory_timesteps: list,
server_args: ServerArgs,
):
# Gather results if using sequence parallelism
if trajectory_latents:
trajectory_tensor = torch.stack(trajectory_latents, dim=1)
trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)
else:
trajectory_tensor = None
trajectory_timesteps_tensor = None
# Gather results if using sequence parallelism
latents, trajectory_tensor = self._postprocess_sp_latents(
batch, latents, trajectory_tensor
)
if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
batch.trajectory_latents = trajectory_tensor.cpu()
# Update batch with final latents
batch.latents = self.server_args.pipeline_config.post_denoising_loop(
latents, batch
)
# Save STA mask search results if needed
if (
st_attn_available
and self.attn_backend == SlidingTileAttentionBackend
and server_args.STA_mode == STA_Mode.STA_SEARCHING
):
self.save_sta_search_results(batch)
# deallocate transformer if on mps
pipeline = self.pipeline() if self.pipeline else None
if torch.backends.mps.is_available():
logger.info(
"Memory before deallocating transformer: %s",
torch.mps.current_allocated_memory(),
)
del self.transformer
if pipeline is not None and "transformer" in pipeline.modules:
del pipeline.modules["transformer"]
server_args.model_loaded["transformer"] = False
logger.info(
"Memory after deallocating transformer: %s",
torch.mps.current_allocated_memory(),
)
def _preprocess_sp_latents(self, batch: Req):
"""Shard latents for Sequence Parallelism if applicable."""
sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank()
if get_sp_world_size() <= 1:
batch.did_sp_shard_latents = False
return
def _shard_tensor(
tensor: torch.Tensor | None,
) -> tuple[torch.Tensor | None, bool]:
if tensor is None:
return None, False
if tensor.dim() == 5:
time_dim = tensor.shape[2]
if time_dim > 0 and time_dim % sp_world_size == 0:
sharded_tensor = rearrange(
tensor, "b c (n t) h w -> b c n t h w", n=sp_world_size
).contiguous()
sharded_tensor = sharded_tensor[:, :, rank_in_sp_group, :, :, :]
return sharded_tensor, True
# For 4D image tensors or unsharded 5D tensors, return as is.
return tensor, False
batch.latents, did_shard = _shard_tensor(batch.latents)
batch.did_sp_shard_latents = did_shard
# image_latent is sharded independently, but the decision to all-gather later
# is based on whether the main `latents` was sharded.
if batch.image_latent is not None:
batch.image_latent, _ = _shard_tensor(batch.image_latent)
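# Sharding example (hypothetical sizes): with latents of shape [B, C, 8, H, W]
# and sp_world_size = 4, the rearrange above splits the time axis into 4 chunks
# of 2 contiguous frames; rank r keeps frames [2r, 2r + 1], i.e. a
# [B, C, 2, H, W] shard, and the full sequence is re-assembled later by
# sequence_model_parallel_all_gather along dim 2.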
def _postprocess_sp_latents(
self,
batch: Req,
latents: torch.Tensor,
trajectory_tensor: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Gather latents after Sequence Parallelism if they were sharded."""
if get_sp_world_size() > 1 and getattr(batch, "did_sp_shard_latents", False):
latents = sequence_model_parallel_all_gather(latents, dim=2)
if trajectory_tensor is not None:
# trajectory_tensor shape: [b, num_steps, c, t_local, h, w] -> gather on dim 3
trajectory_tensor = trajectory_tensor.to(get_local_torch_device())
trajectory_tensor = sequence_model_parallel_all_gather(
trajectory_tensor, dim=3
)
return latents, trajectory_tensor
def start_profile(self, batch: Req):
if not batch.profile:
return
logger.info("Starting Profiler...")
# Build activities dynamically to avoid CUDA hangs when CUDA is unavailable
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)
prof = torch.profiler.profile(
activities=activities,
schedule=torch.profiler.schedule(
skip_first=0,
wait=0,
warmup=5,
active=batch.num_profiled_timesteps,
repeat=5,
),
# pass the handler itself so the profiler invokes it when a trace is ready
on_trace_ready=torch.profiler.tensorboard_trace_handler("./logs"),
record_shapes=True,
with_stack=True,
)
prof.start()
self.profiler = prof
def step_profile(self):
if self.profiler:
if torch.cuda.is_available():
torch.cuda.synchronize()
self.profiler.step()
def stop_profile(self, batch: Req):
try:
if self.profiler:
logger.info("Stopping Profiler...")
if torch.cuda.is_available():
torch.cuda.synchronize()
self.profiler.stop()
request_id = batch.request_id if batch.request_id else "profile_trace"
log_dir = f"./logs"
os.makedirs(log_dir, exist_ok=True)
trace_path = os.path.abspath(
os.path.join(log_dir, f"{request_id}.trace.json.gz")
)
logger.info(f"Saving profiler traces to: {trace_path}")
self.profiler.export_chrome_trace(trace_path)
except Exception as e:
logger.error(f"{e}")
def _manage_device_placement(
self,
model_to_use: torch.nn.Module,
model_to_offload: torch.nn.Module | None,
server_args: ServerArgs,
):
"""
Manages the offload / load behavior of dit
"""
if not server_args.dit_cpu_offload:
return
# Offload the unused model if it's on CUDA
if (
model_to_offload is not None
and next(model_to_offload.parameters()).device.type == "cuda"
):
model_to_offload.to("cpu")
# Load the model to use if it's on CPU
if (
model_to_use is not None
and next(model_to_use.parameters()).device.type == "cpu"
):
model_to_use.to(get_local_torch_device())
def _select_and_manage_model(
self,
t_int: int,
boundary_timestep: float | None,
server_args: ServerArgs,
batch: Req,
):
if boundary_timestep is None or t_int >= boundary_timestep:
# High-noise stage
current_model = self.transformer
model_to_offload = self.transformer_2
current_guidance_scale = batch.guidance_scale
else:
# Low-noise stage
current_model = self.transformer_2
model_to_offload = self.transformer
current_guidance_scale = batch.guidance_scale_2
self._manage_device_placement(current_model, model_to_offload, server_args)
assert current_model is not None, "The model for the current step is not set."
return current_model, current_guidance_scale
@torch.no_grad()
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Run the denoising loop.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with denoised latents.
"""
# Prepare variables for the denoising loop
prepared_vars = self._prepare_denoising_loop(batch, server_args)
extra_step_kwargs = prepared_vars["extra_step_kwargs"]
target_dtype = prepared_vars["target_dtype"]
autocast_enabled = prepared_vars["autocast_enabled"]
timesteps = prepared_vars["timesteps"]
num_inference_steps = prepared_vars["num_inference_steps"]
num_warmup_steps = prepared_vars["num_warmup_steps"]
image_kwargs = prepared_vars["image_kwargs"]
pos_cond_kwargs = prepared_vars["pos_cond_kwargs"]
neg_cond_kwargs = prepared_vars["neg_cond_kwargs"]
latents = prepared_vars["latents"]
boundary_timestep = prepared_vars["boundary_timestep"]
z = prepared_vars["z"]
mask2 = prepared_vars["mask2"]
seq_len = prepared_vars["seq_len"]
guidance = prepared_vars["guidance"]
# Initialize lists for ODE trajectory
trajectory_timesteps: list[torch.Tensor] = []
trajectory_latents: list[torch.Tensor] = []
# Run denoising loop
denoising_start_time = time.time()
self.start_profile(batch=batch)
# to avoid device-sync caused by timestep comparison
timesteps_cpu = timesteps.cpu()
num_timesteps = timesteps_cpu.shape[0]
with torch.autocast(
device_type=("cuda" if torch.cuda.is_available() else "cpu"),
dtype=target_dtype,
enabled=autocast_enabled,
):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t_host in enumerate(timesteps_cpu):
if batch.perf_logger:
batch.perf_logger.record_step_start()
# Skip if interrupted
if hasattr(self, "interrupt") and self.interrupt:
continue
t_int = int(t_host.item())
t_device = timesteps[i]
current_model, current_guidance_scale = (
self._select_and_manage_model(
t_int=t_int,
boundary_timestep=boundary_timestep,
server_args=server_args,
batch=batch,
)
)
# Expand latents for I2V
latent_model_input = latents.to(target_dtype)
if batch.image_latent is not None:
assert (
not server_args.pipeline_config.ti2v_task
), "image latents should not be provided for TI2V task"
latent_model_input = torch.cat(
[latent_model_input, batch.image_latent], dim=1
).to(target_dtype)
# expand timestep
if (
server_args.pipeline_config.ti2v_task
and batch.pil_image is not None
):
timestep = torch.stack([t_device]).to(get_local_torch_device())
temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
temp_ts = torch.cat(
[
temp_ts,
temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep,
]
)
timestep = temp_ts.unsqueeze(0)
t_expand = timestep.repeat(latent_model_input.shape[0], 1)
else:
t_expand = t_device.repeat(latent_model_input.shape[0])
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t_device
)
# Predict noise residual
attn_metadata = self._build_attn_metadata(i, batch, server_args)
noise_pred = self._predict_noise_with_cfg(
current_model,
latent_model_input,
t_expand,
batch,
i,
attn_metadata,
target_dtype,
current_guidance_scale,
image_kwargs,
pos_cond_kwargs,
neg_cond_kwargs,
server_args,
guidance=guidance,
latents=latents,
)
if batch.perf_logger:
batch.perf_logger.record_step_end("denoising_step_guided", i)
# Compute the previous noisy sample
latents = self.scheduler.step(
model_output=noise_pred,
timestep=t_device,
sample=latents,
**extra_step_kwargs,
return_dict=False,
)[0]
if (
server_args.pipeline_config.ti2v_task
and batch.pil_image is not None
):
latents = latents.squeeze(0)
latents = (1.0 - mask2[0]) * z + mask2[0] * latents
# save trajectory latents if needed
if batch.return_trajectory_latents:
trajectory_timesteps.append(t_host)
trajectory_latents.append(latents)
# Update progress bar
if i == num_timesteps - 1 or (
(i + 1) > num_warmup_steps
and (i + 1) % self.scheduler.order == 0
and progress_bar is not None
):
progress_bar.update()
self.step_profile()
self.stop_profile(batch)
denoising_end_time = time.time()
if num_timesteps > 0:
self.log_info(
"Average time per step: %.4f seconds",
(denoising_end_time - denoising_start_time) / len(timesteps),
)
self._post_denoising_loop(
batch=batch,
latents=latents,
trajectory_latents=trajectory_latents,
trajectory_timesteps=trajectory_timesteps,
server_args=server_args,
)
return batch
# TODO: this extends the preparation stage; subclasses or passed-in variables should decide what to prepare
def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
"""
Prepare extra kwargs for the scheduler step / denoise step.
Args:
func: The function to prepare kwargs for.
kwargs: The kwargs to prepare.
Returns:
The prepared kwargs.
"""
extra_step_kwargs = {}
for k, v in kwargs.items():
accepts = k in set(inspect.signature(func).parameters.keys())
if accepts:
extra_step_kwargs[k] = v
return extra_step_kwargs
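# Illustrative example (hypothetical scheduler): if `scheduler.step` accepts a
# `generator` argument but not `eta`, then
#     prepare_extra_func_kwargs(scheduler.step, {"generator": g, "eta": 0.0})
# returns {"generator": g}; unsupported keys are silently dropped rather than
# raising a TypeError at call time.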
def progress_bar(
self, iterable: Iterable | None = None, total: int | None = None
) -> tqdm:
"""
Create a progress bar for the denoising process.
Args:
iterable: The iterable to iterate over.
total: The total number of items.
Returns:
A tqdm progress bar.
"""
local_rank = get_world_group().local_rank
if local_rank == 0:
return tqdm(iterable=iterable, total=total)
else:
return tqdm(iterable=iterable, total=total, disable=True)
def rescale_noise_cfg(
self, noise_cfg, noise_pred_text, guidance_rescale=0.0
) -> torch.Tensor:
"""
Rescale noise prediction according to guidance_rescale.
Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
(https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
Args:
noise_cfg: The noise prediction with guidance.
noise_pred_text: The text-conditioned noise prediction.
guidance_rescale: The guidance rescale factor.
Returns:
The rescaled noise prediction.
"""
std_text = noise_pred_text.std(
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# Rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# Mix with the original results from guidance by factor guidance_rescale
noise_cfg = (
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
)
return noise_cfg
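# In equation form (Section 3.4 of the paper cited above), with
# phi = guidance_rescale:
#     x_rescaled = x_cfg * std(x_text) / std(x_cfg)
#     x_out      = phi * x_rescaled + (1 - phi) * x_cfg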
def _build_attn_metadata(
self, i: int, batch: Req, server_args: ServerArgs
) -> Any | None:
"""
Build attention metadata for custom attention backends.
Args:
i: The current timestep index.
batch: The current batch information.
server_args: The inference arguments.
Returns:
The attention metadata, or None if not applicable.
"""
attn_metadata = None
self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
if self.attn_metadata_builder_cls:
self.attn_metadata_builder = self.attn_metadata_builder_cls()
if (st_attn_available and self.attn_backend == SlidingTileAttentionBackend) or (
vsa_available and self.attn_backend == VideoSparseAttentionBackend
):
attn_metadata = self.attn_metadata_builder.build(
current_timestep=i,
raw_latent_shape=batch.raw_latent_shape[2:5],
patch_size=server_args.pipeline_config.dit_config.patch_size,
STA_param=batch.STA_param,
VSA_sparsity=server_args.VSA_sparsity,
device=get_local_torch_device(),
)
elif vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend:
moba_params = server_args.moba_config.copy()
moba_params.update(
{
"current_timestep": i,
"raw_latent_shape": batch.raw_latent_shape[2:5],
"patch_size": server_args.pipeline_config.dit_config.patch_size,
"device": get_local_torch_device(),
}
)
# build metadata from the assembled parameters (assumed builder signature;
# without this call attn_metadata stays None and the assert below fails)
attn_metadata = self.attn_metadata_builder.build(**moba_params)
elif self.attn_backend == FlashAttentionBackend:
attn_metadata = self.attn_metadata_builder.build(
raw_latent_shape=batch.raw_latent_shape
)
else:
return None
assert attn_metadata is not None, "attn_metadata cannot be None"
return attn_metadata
def _predict_noise(
self,
current_model,
latent_model_input,
t_expand,
prompt_embeds,
target_dtype,
guidance: torch.Tensor,
**kwargs,
):
return current_model(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=t_expand,
guidance=guidance,
**kwargs,
)
def _predict_noise_with_cfg(
self,
current_model: torch.nn.Module,
latent_model_input: torch.Tensor,
t_expand,
batch,
timestep_index: int,
attn_metadata,
target_dtype,
current_guidance_scale,
image_kwargs: dict[str, Any],
pos_cond_kwargs: dict[str, Any],
neg_cond_kwargs: dict[str, Any],
server_args,
guidance,
latents,
):
"""
Predict the noise residual with classifier-free guidance.
Args:
current_model: The transformer model to use for the current step.
latent_model_input: The input latents for the model.
t_expand: The expanded timestep tensor.
batch: The current batch information.
timestep_index: The current timestep index.
attn_metadata: Attention metadata for custom backends.
target_dtype: The target data type for autocasting.
current_guidance_scale: The guidance scale for the current step.
image_kwargs: Keyword arguments for image conditioning.
pos_cond_kwargs: Keyword arguments for positive prompt conditioning.
neg_cond_kwargs: Keyword arguments for negative prompt conditioning.
Returns:
The predicted noise.
"""
noise_pred_cond: torch.Tensor | None = None
noise_pred_uncond: torch.Tensor | None = None
cfg_rank = get_classifier_free_guidance_rank()
# positive pass
if not (server_args.enable_cfg_parallel and cfg_rank != 0):
batch.is_cfg_negative = False
with set_forward_context(
current_timestep=timestep_index,
attn_metadata=attn_metadata,
forward_batch=batch,
):
noise_pred_cond = self._predict_noise(
current_model=current_model,
latent_model_input=latent_model_input,
t_expand=t_expand,
prompt_embeds=server_args.pipeline_config.get_pos_prompt_embeds(
batch
),
target_dtype=target_dtype,
guidance=guidance,
**image_kwargs,
**pos_cond_kwargs,
)
# TODO: can it be moved to after _predict_noise_with_cfg?
noise_pred_cond = server_args.pipeline_config.slice_noise_pred(
noise_pred_cond, latents
)
if not batch.do_classifier_free_guidance:
# If CFG is disabled, we are done. Return the conditional prediction.
return noise_pred_cond
# negative pass
if not server_args.enable_cfg_parallel or cfg_rank != 0:
batch.is_cfg_negative = True
with set_forward_context(
current_timestep=timestep_index,
attn_metadata=attn_metadata,
forward_batch=batch,
):
noise_pred_uncond = self._predict_noise(
current_model=current_model,
latent_model_input=latent_model_input,
t_expand=t_expand,
prompt_embeds=server_args.pipeline_config.get_neg_prompt_embeds(
batch
),
target_dtype=target_dtype,
guidance=guidance,
**image_kwargs,
**neg_cond_kwargs,
)
noise_pred_uncond = server_args.pipeline_config.slice_noise_pred(
noise_pred_uncond, latents
)
# Combine predictions
if server_args.enable_cfg_parallel:
# Each rank computes its partial contribution and we sum via all-reduce:
# final = s*cond + (1-s)*uncond
if cfg_rank == 0:
assert noise_pred_cond is not None
partial = current_guidance_scale * noise_pred_cond
else:
assert noise_pred_uncond is not None
partial = (1 - current_guidance_scale) * noise_pred_uncond
noise_pred = cfg_model_parallel_all_reduce(partial)
# Guidance rescale: broadcast std(cond) from rank 0, compute std(cfg) locally
if batch.guidance_rescale > 0.0:
std_cfg = noise_pred.std(
dim=list(range(1, noise_pred.ndim)), keepdim=True
)
if cfg_rank == 0:
assert noise_pred_cond is not None
std_text = noise_pred_cond.std(
dim=list(range(1, noise_pred_cond.ndim)), keepdim=True
)
else:
std_text = torch.empty_like(std_cfg)
# Broadcast std_text from local src=0 to all ranks in CFG group
std_text = get_cfg_group().broadcast(std_text, src=0)
noise_pred_rescaled = noise_pred * (std_text / std_cfg)
noise_pred = (
batch.guidance_rescale * noise_pred_rescaled
+ (1 - batch.guidance_rescale) * noise_pred
)
return noise_pred
else:
# Serial CFG: both cond and uncond are available locally
assert noise_pred_cond is not None and noise_pred_uncond is not None
noise_pred = noise_pred_uncond + current_guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
if batch.guidance_rescale > 0.0:
noise_pred = self.rescale_noise_cfg(
noise_pred,
noise_pred_cond,
guidance_rescale=batch.guidance_rescale,
)
return noise_pred
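# Note on the two CFG paths above: with guidance scale s they compute the same
# quantity, since
#     uncond + s * (cond - uncond) == s * cond + (1 - s) * uncond,
# so under CFG parallelism rank 0 contributes s * cond, the rank handling the
# negative pass contributes (1 - s) * uncond, and the all-reduce sum reproduces
# the serial result (the guidance-rescale step is likewise split by
# broadcasting std(cond) from rank 0).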
def prepare_sta_param(self, batch: Req, server_args: ServerArgs):
"""
Prepare Sliding Tile Attention (STA) parameters and settings.
Args:
batch: The current batch information.
server_args: The inference arguments.
"""
# TODO(kevin): STA mask search currently only supports Wan2.1 at 69x768x1280
STA_mode = server_args.STA_mode
skip_time_steps = server_args.skip_time_steps
if batch.timesteps is None:
raise ValueError("Timesteps must be provided")
timesteps_num = batch.timesteps.shape[0]
logger.info("STA_mode: %s", STA_mode)
if (batch.num_frames, batch.height, batch.width) != (
69,
768,
1280,
) and STA_mode != STA_Mode.STA_INFERENCE:
raise NotImplementedError(
"STA mask search/tuning is not supported for this resolution"
)
if (
STA_mode == STA_Mode.STA_SEARCHING
or STA_mode == STA_Mode.STA_TUNING
or STA_mode == STA_Mode.STA_TUNING_CFG
):
size = (batch.width, batch.height)
if size == (1280, 768):
# TODO: make it configurable
sparse_mask_candidates_searching = [
"3, 1, 10",
"1, 5, 7",
"3, 3, 3",
"1, 6, 5",
"1, 3, 10",
"3, 6, 1",
]
sparse_mask_candidates_tuning = [
"3, 1, 10",
"1, 5, 7",
"3, 3, 3",
"1, 6, 5",
"1, 3, 10",
"3, 6, 1",
]
full_mask = ["3,6,10"]
else:
raise NotImplementedError(
"STA mask search is not supported for this resolution"
)
layer_num = self.transformer.config.num_layers
# specific for HunyuanVideo
if hasattr(self.transformer.config, "num_single_layers"):
layer_num += self.transformer.config.num_single_layers
head_num = self.transformer.config.num_attention_heads
if STA_mode == STA_Mode.STA_SEARCHING:
STA_param = configure_sta(
mode=STA_Mode.STA_SEARCHING,
layer_num=layer_num,
head_num=head_num,
time_step_num=timesteps_num,
mask_candidates=sparse_mask_candidates_searching + full_mask,
# the last entry is the full mask; more sparse masks can be added as long as the last one stays the full mask
)
elif STA_mode == STA_Mode.STA_TUNING:
STA_param = configure_sta(
mode=STA_Mode.STA_TUNING,
layer_num=layer_num,
head_num=head_num,
time_step_num=timesteps_num,
mask_search_files_path=f"output/mask_search_result_pos_{size[0]}x{size[1]}/",
mask_candidates=sparse_mask_candidates_tuning,
full_attention_mask=[int(x) for x in full_mask[0].split(",")],
skip_time_steps=skip_time_steps,  # use full attention for the first `skip_time_steps` steps
save_dir=f"output/mask_search_strategy_{size[0]}x{size[1]}/", # Custom save directory
timesteps=timesteps_num,
)
elif STA_mode == STA_Mode.STA_TUNING_CFG:
STA_param = configure_sta(
mode=STA_Mode.STA_TUNING_CFG,
layer_num=layer_num,
head_num=head_num,
time_step_num=timesteps_num,
mask_search_files_path_pos=f"output/mask_search_result_pos_{size[0]}x{size[1]}/",
mask_search_files_path_neg=f"output/mask_search_result_neg_{size[0]}x{size[1]}/",
mask_candidates=sparse_mask_candidates_tuning,
full_attention_mask=[int(x) for x in full_mask[0].split(",")],
skip_time_steps=skip_time_steps,
save_dir=f"output/mask_search_strategy_{size[0]}x{size[1]}/",
timesteps=timesteps_num,
)
elif STA_mode == STA_Mode.STA_INFERENCE:
import sglang.multimodal_gen.envs as envs
config_file = envs.SGL_DIFFUSION_ATTENTION_CONFIG
if config_file is None:
raise ValueError("SGL_DIFFUSION_ATTENTION_CONFIG is not set")
STA_param = configure_sta(
mode=STA_Mode.STA_INFERENCE,
layer_num=layer_num,
head_num=head_num,
time_step_num=timesteps_num,
load_path=config_file,
)
batch.STA_param = STA_param
batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)]
batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)]
def save_sta_search_results(self, batch: Req):
"""
Save the STA mask search results.
Args:
batch: The current batch information.
"""
size = (batch.width, batch.height)
if size == (1280, 768):
# TODO: make it configurable
sparse_mask_candidates_searching = [
"3, 1, 10",
"1, 5, 7",
"3, 3, 3",
"1, 6, 5",
"1, 3, 10",
"3, 6, 1",
]
else:
raise NotImplementedError(
"STA mask search is not supported for this resolution"
)
if batch.mask_search_final_result_pos is not None and batch.prompt is not None:
save_mask_search_results(
[dict(layer_data) for layer_data in batch.mask_search_final_result_pos],
prompt=str(batch.prompt),
mask_strategies=sparse_mask_candidates_searching,
output_dir=f"output/mask_search_result_pos_{size[0]}x{size[1]}/",
)
if batch.mask_search_final_result_neg is not None and batch.prompt is not None:
save_mask_search_results(
[dict(layer_data) for layer_data in batch.mask_search_final_result_neg],
prompt=str(batch.prompt),
mask_strategies=sparse_mask_candidates_searching,
output_dir=f"output/mask_search_result_neg_{size[0]}x{size[1]}/",
)
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify denoising stage inputs."""
result = VerificationResult()
result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)])
# disable temporarily for image-generation models
# result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
result.add_check("image_embeds", batch.image_embeds, V.is_list)
# result.add_check(
# "image_latent", batch.image_latent, V.none_or_tensor_with_dims(5)
# )
result.add_check(
"num_inference_steps", batch.num_inference_steps, V.positive_int
)
result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
result.add_check("eta", batch.eta, V.non_negative_float)
result.add_check("generator", batch.generator, V.generator_or_list_generators)
result.add_check(
"do_classifier_free_guidance",
batch.do_classifier_free_guidance,
V.bool_value,
)
result.add_check(
"negative_prompt_embeds",
batch.negative_prompt_embeds,
lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x),
)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify denoising stage outputs."""
result = VerificationResult()
# result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
return result
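# A minimal, standalone sketch (not part of the original file) of the serial
# classifier-free-guidance combination used in the denoising stages above:
# noise_pred = uncond + s * (cond - uncond), optionally rescaled so the guided
# prediction keeps the per-sample std of the conditional branch. The helper
# name `cfg_combine` and the tensor shapes are illustrative assumptions.
import torch


def cfg_combine(
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    guidance_scale: float,
    guidance_rescale: float = 0.0,
) -> torch.Tensor:
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_pred_cond - noise_pred_uncond
    )
    if guidance_rescale > 0.0:
        # Rescale toward the std of the conditional prediction, then blend.
        dims = list(range(1, noise_pred.ndim))
        std_cond = noise_pred_cond.std(dim=dims, keepdim=True)
        std_cfg = noise_pred.std(dim=dims, keepdim=True)
        rescaled = noise_pred * (std_cond / std_cfg)
        noise_pred = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_pred
    return noise_pred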
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
import time
import torch
from einops import rearrange
from sglang.multimodal_gen.runtime.distributed import (
get_local_torch_device,
get_sp_parallel_rank,
get_sp_world_size,
logger,
sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import (
SlidingTileAttentionBackend,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import (
VideoSparseAttentionBackend,
)
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import (
FlowMatchEulerDiscreteScheduler,
)
from sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages import DenoisingStage
from sglang.multimodal_gen.runtime.pipelines.stages.denoising import (
st_attn_available,
vsa_available,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.utils import dict_to_3d_list
# TODO: use base methods of DenoisingStage
class DmdDenoisingStage(DenoisingStage):
"""
Denoising stage for DMD.
"""
def __init__(self, transformer, scheduler) -> None:
super().__init__(transformer, scheduler)
self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Run the denoising loop.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with denoised latents.
"""
# Setup precision and autocast settings
# TODO(will): make the precision configurable for inference
# target_dtype = PRECISION_TO_TYPE[server_args.precision]
target_dtype = torch.bfloat16
autocast_enabled = (
target_dtype != torch.float32
) and not server_args.disable_autocast
# Get timesteps and calculate warmup steps
timesteps = batch.timesteps
# TODO(will): remove this once we add input/output validation for stages
if timesteps is None:
raise ValueError("Timesteps must be provided")
num_inference_steps = batch.num_inference_steps
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# Prepare image latents and embeddings for I2V generation
image_embeds = batch.image_embeds
if len(image_embeds) > 0:
assert torch.isnan(image_embeds[0]).sum() == 0
image_embeds = [
image_embed.to(target_dtype) for image_embed in image_embeds
]
image_kwargs = self.prepare_extra_func_kwargs(
self.transformer.forward,
{
"encoder_hidden_states_image": image_embeds,
"mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24),
},
)
pos_cond_kwargs = self.prepare_extra_func_kwargs(
self.transformer.forward,
{
"encoder_hidden_states_2": batch.clip_embedding_pos,
"encoder_attention_mask": batch.prompt_attention_mask,
},
)
# Prepare STA parameters
if st_attn_available and self.attn_backend == SlidingTileAttentionBackend:
self.prepare_sta_param(batch, server_args)
# Get latents and embeddings
assert batch.latents is not None, "latents must be provided"
latents = batch.latents
latents = latents.permute(0, 2, 1, 3, 4)
video_raw_latent_shape = latents.shape
prompt_embeds = batch.prompt_embeds
assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
timesteps = torch.tensor(
server_args.pipeline_config.dmd_denoising_steps,
dtype=torch.long,
device=get_local_torch_device(),
)
# Handle sequence parallelism if enabled
sp_world_size, rank_in_sp_group = (
get_sp_world_size(),
get_sp_parallel_rank(),
)
sp_group = sp_world_size > 1
if sp_group:
latents = rearrange(
latents, "b (n t) c h w -> b n t c h w", n=sp_world_size
).contiguous()
latents = latents[:, rank_in_sp_group, :, :, :, :]
if batch.image_latent is not None:
image_latent = rearrange(
batch.image_latent,
"b c (n t) h w -> b c n t h w",
n=sp_world_size,
).contiguous()
image_latent = image_latent[:, :, rank_in_sp_group, :, :, :]
batch.image_latent = image_latent
# Run denoising loop
denoising_loop_start_time = time.time()
with self.progress_bar(total=len(timesteps)) as progress_bar:
for i, t in enumerate(timesteps):
# Skip if interrupted
if hasattr(self, "interrupt") and self.interrupt:
continue
# Expand latents for I2V
noise_latents = latents.clone()
latent_model_input = latents.to(target_dtype)
if batch.image_latent is not None:
latent_model_input = torch.cat(
[
latent_model_input,
batch.image_latent.permute(0, 2, 1, 3, 4),
],
dim=2,
).to(target_dtype)
assert not torch.isnan(
latent_model_input
).any(), "latent_model_input contains nan"
# Prepare inputs for transformer
t_expand = t.repeat(latent_model_input.shape[0])
guidance_expand = (
torch.tensor(
[server_args.pipeline_config.embedded_cfg_scale]
* latent_model_input.shape[0],
dtype=torch.float32,
device=get_local_torch_device(),
).to(target_dtype)
* 1000.0
if server_args.pipeline_config.embedded_cfg_scale is not None
else None
)
# Predict noise residual
with torch.autocast(
device_type="cuda",
dtype=target_dtype,
enabled=autocast_enabled,
):
if (
vsa_available
and self.attn_backend == VideoSparseAttentionBackend
):
self.attn_metadata_builder_cls = (
self.attn_backend.get_builder_cls()
)
if self.attn_metadata_builder_cls is not None:
self.attn_metadata_builder = (
self.attn_metadata_builder_cls()
)
# TODO(will): clean this up
attn_metadata = self.attn_metadata_builder.build( # type: ignore
current_timestep=i, # type: ignore
raw_latent_shape=batch.raw_latent_shape[2:5], # type: ignore
patch_size=server_args.pipeline_config.dit_config.patch_size, # type: ignore
STA_param=batch.STA_param, # type: ignore
VSA_sparsity=server_args.VSA_sparsity, # type: ignore
device=get_local_torch_device(), # type: ignore
) # type: ignore
assert (
attn_metadata is not None
), "attn_metadata cannot be None"
else:
attn_metadata = None
else:
attn_metadata = None
batch.is_cfg_negative = False
with set_forward_context(
current_timestep=i,
attn_metadata=attn_metadata,
forward_batch=batch,
# server_args=server_args
):
# Run transformer
pred_noise = self.transformer(
latent_model_input.permute(0, 2, 1, 3, 4),
prompt_embeds,
t_expand,
guidance=guidance_expand,
**image_kwargs,
**pos_cond_kwargs,
).permute(0, 2, 1, 3, 4)
pred_video = pred_noise_to_pred_video(
pred_noise=pred_noise.flatten(0, 1),
noise_input_latent=noise_latents.flatten(0, 1),
timestep=t_expand,
scheduler=self.scheduler,
).unflatten(0, pred_noise.shape[:2])
if i < len(timesteps) - 1:
next_timestep = timesteps[i + 1] * torch.ones(
[1], dtype=torch.long, device=pred_video.device
)
noise = torch.randn(
video_raw_latent_shape,
dtype=pred_video.dtype,
generator=batch.generator[0],
).to(self.device)
if sp_group:
noise = rearrange(
noise,
"b (n t) c h w -> b n t c h w",
n=sp_world_size,
).contiguous()
noise = noise[:, rank_in_sp_group, :, :, :, :]
latents = self.scheduler.add_noise(
pred_video.flatten(0, 1),
noise.flatten(0, 1),
next_timestep,
).unflatten(0, pred_video.shape[:2])
else:
latents = pred_video
# Update progress bar
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps
and (i + 1) % self.scheduler.order == 0
and progress_bar is not None
):
progress_bar.update()
denoising_loop_end_time = time.time()
if len(timesteps) > 0:
logger.info(
"Average time per step: %.4f seconds",
(denoising_loop_end_time - denoising_loop_start_time) / len(timesteps),
)
# Gather results if using sequence parallelism
if sp_group:
latents = sequence_model_parallel_all_gather(latents, dim=1)
latents = latents.permute(0, 2, 1, 3, 4)
# Update batch with final latents
batch.latents = latents
return batch
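# A minimal sketch (illustrative, not part of the original file) of one DMD
# step as performed in the loop above: turn the model's noise/velocity
# prediction into a clean-video estimate, then re-noise it at the next
# timestep. It assumes the flow-matching convention
# x_t = (1 - sigma) * x0 + sigma * noise with the model predicting
# v ~= noise - x0; the real helpers (pred_noise_to_pred_video and
# FlowMatchEulerDiscreteScheduler.add_noise) may differ in details.
import torch


def dmd_step_sketch(
    latent_t: torch.Tensor,
    pred_noise: torch.Tensor,
    sigma_t: float,
    sigma_next: float,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    # Clean-video estimate from the current noisy latent and predicted velocity.
    pred_video = latent_t - sigma_t * pred_noise
    # Re-noise the estimate at the next (smaller) noise level.
    noise = torch.randn(pred_video.shape, dtype=pred_video.dtype, generator=generator)
    return (1 - sigma_next) * pred_video + sigma_next * noise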
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Encoding stage for diffusion pipelines.
"""
import torch
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
V, # Import validators
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import PRECISION_TO_TYPE
logger = init_logger(__name__)
class EncodingStage(PipelineStage):
"""
Stage for encoding pixel space representations into latent space.
This stage handles the encoding of pixel-space video/images into latent
representations for further processing in the diffusion pipeline.
"""
def __init__(self, vae: ParallelTiledVAE) -> None:
self.vae: ParallelTiledVAE = vae
@torch.no_grad()
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify encoding stage inputs."""
result = VerificationResult()
# Input video/images for VAE encoding: [batch_size, channels, frames, height, width]
result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify encoding stage outputs."""
result = VerificationResult()
# Encoded latents: [batch_size, channels, frames, height_latents, width_latents]
result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
return result
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Encode pixel space representations into latent space.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with encoded latents.
"""
assert batch.latents is not None and isinstance(batch.latents, torch.Tensor)
self.vae = self.vae.to(get_local_torch_device())
# Setup VAE precision
vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
vae_autocast_enabled = (
vae_dtype != torch.float32
) and not server_args.disable_autocast
# Normalize input to [-1, 1] range (reverse of decoding normalization)
latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1)
# Move to appropriate device and dtype
latents = latents.to(get_local_torch_device())
# Encode image to latents
with torch.autocast(
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
):
if server_args.pipeline_config.vae_tiling:
self.vae.enable_tiling()
# if server_args.vae_sp:
# self.vae.enable_parallel()
if not vae_autocast_enabled:
latents = latents.to(vae_dtype)
latents = self.vae.encode(latents).mean
# Update batch with encoded latents
batch.latents = latents
# Offload models if needed
if hasattr(self, "maybe_free_model_hooks"):
self.maybe_free_model_hooks()
if server_args.vae_cpu_offload:
self.vae.to("cpu")
return batch
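# A minimal sketch (illustrative) of the pixel normalization used by
# EncodingStage above and the corresponding inverse on the decode side:
# encoding maps [0, 1] pixels into the [-1, 1] range the VAE expects, decoding
# maps samples back to [0, 1]. Function names are assumptions, not pipeline API.
import torch


def to_vae_range(pixels: torch.Tensor) -> torch.Tensor:
    # [0, 1] -> [-1, 1], clamped, matching the encode path above.
    return (pixels * 2.0 - 1.0).clamp(-1, 1)


def from_vae_range(samples: torch.Tensor) -> torch.Tensor:
    # [-1, 1] -> [0, 1], the inverse normalization applied after decoding.
    return (samples / 2.0 + 0.5).clamp(0, 1)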
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Image encoding stages for I2V diffusion pipelines.
This module contains implementations of image encoding stages for diffusion pipelines.
"""
import PIL
import torch
from sglang.multimodal_gen.configs.pipelines.qwen_image import (
QwenImageEditPipelineConfig,
QwenImagePipelineConfig,
_pack_latents,
qwen_image_postprocess_text,
)
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE
from sglang.multimodal_gen.runtime.models.vision_utils import (
normalize,
numpy_to_pt,
pil_to_numpy,
resize,
)
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ExecutionMode, ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import PRECISION_TO_TYPE
logger = init_logger(__name__)
class ImageEncodingStage(PipelineStage):
"""
Stage for encoding image prompts into embeddings for diffusion models.
This stage handles the encoding of image prompts into the embedding space
expected by the diffusion model.
"""
def __init__(
self,
image_processor,
image_encoder=None,
text_encoder=None,
vae_image_processor=None,
) -> None:
"""
Initialize the prompt encoding stage.
Args:
text_encoder: An encoder to encode input_ids and pixel values
"""
super().__init__()
self.image_processor = image_processor
self.vae_image_processor = vae_image_processor
self.image_encoder = image_encoder
self.text_encoder = text_encoder
def move_to_device(self, device):
fields = [
"image_processor",
"image_encoder",
]
for field in fields:
processor = getattr(self, field, None)
if processor and hasattr(processor, "to"):
setattr(self, field, processor.to(device))
def encoding_qwen_image_edit(self, outputs, image_inputs):
# encoder hidden state
prompt_embeds = qwen_image_postprocess_text(outputs, image_inputs, 64)
return prompt_embeds
@torch.inference_mode()
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Encode the prompt into image encoder hidden states.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with encoded prompt embeddings.
"""
cuda_device = get_local_torch_device()
self.move_to_device(cuda_device)
image = batch.pil_image
# preprocess the input image with the VAE image processor
prompt_image = server_args.pipeline_config.preprocess_image(
image, self.vae_image_processor
)
if batch.prompt and (
isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig)
or isinstance(server_args.pipeline_config, QwenImagePipelineConfig)
):
prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
txt = prompt_template_encode.format(batch.prompt)
image_processor_kwargs = dict(text=[txt], padding=True)
else:
image_processor_kwargs = {}
image_inputs = self.image_processor(
images=prompt_image, return_tensors="pt", **image_processor_kwargs
).to(cuda_device)
if self.image_encoder:
# if an image encoder is provided
with set_forward_context(current_timestep=0, attn_metadata=None):
outputs = self.image_encoder(
**image_inputs,
**server_args.pipeline_config.image_encoder_extra_args,
)
image_embeds = server_args.pipeline_config.postprocess_image(outputs)
batch.image_embeds.append(image_embeds)
elif self.text_encoder:
# if a text encoder is provided, e.g. Qwen-Image-Edit
# build the negative-prompt inputs (used for the CFG embeddings below)
if batch.prompt:
prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
txt = prompt_template_encode.format(batch.negative_prompt)
neg_image_processor_kwargs = dict(text=[txt], padding=True)
else:
neg_image_processor_kwargs = {}
neg_image_inputs = self.image_processor(
images=prompt_image, return_tensors="pt", **neg_image_processor_kwargs
).to(get_local_torch_device())
with set_forward_context(current_timestep=0, attn_metadata=None):
outputs = self.text_encoder(
input_ids=image_inputs.input_ids,
attention_mask=image_inputs.attention_mask,
pixel_values=image_inputs.pixel_values,
image_grid_thw=image_inputs.image_grid_thw,
output_hidden_states=True,
)
neg_outputs = self.text_encoder(
input_ids=neg_image_inputs.input_ids,
attention_mask=neg_image_inputs.attention_mask,
pixel_values=neg_image_inputs.pixel_values,
image_grid_thw=neg_image_inputs.image_grid_thw,
output_hidden_states=True,
)
batch.prompt_embeds.append(
self.encoding_qwen_image_edit(outputs, image_inputs)
)
batch.negative_prompt_embeds.append(
self.encoding_qwen_image_edit(neg_outputs, neg_image_inputs)
)
self.move_to_device("cpu")
return batch
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify image encoding stage inputs."""
result = VerificationResult()
if batch.debug:
logger.debug(f"{batch.pil_image=}")
logger.debug(f"{batch.image_embeds=}")
result.add_check("pil_image", batch.pil_image, V.not_none)
result.add_check("image_embeds", batch.image_embeds, V.is_list)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify image encoding stage outputs."""
result = VerificationResult()
# result.add_check("image_embeds", batch.image_embeds, V.list_of_tensors_dims(3))
return result
class ImageVAEEncodingStage(PipelineStage):
"""
Stage for encoding pixel representations into latent space.
This stage handles the encoding of pixel representations into the final
input format (e.g., latents).
"""
def __init__(self, vae: ParallelTiledVAE, **kwargs) -> None:
super().__init__()
self.vae: ParallelTiledVAE = vae
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Encode pixel representations into latent space.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with encoded outputs.
"""
assert batch.pil_image is not None
if server_args.mode == ExecutionMode.INFERENCE:
assert batch.pil_image is not None and isinstance(
batch.pil_image, PIL.Image.Image
)
assert batch.height is not None and isinstance(batch.height, int)
assert batch.width is not None and isinstance(batch.width, int)
assert batch.num_frames is not None and isinstance(batch.num_frames, int)
height = batch.height
width = batch.width
num_frames = batch.num_frames
elif server_args.mode == ExecutionMode.PREPROCESS:
assert batch.pil_image is not None and isinstance(
batch.pil_image, torch.Tensor
)
assert batch.height is not None and isinstance(batch.height, list)
assert batch.width is not None and isinstance(batch.width, list)
assert batch.num_frames is not None and isinstance(batch.num_frames, list)
num_frames = batch.num_frames[0]
height = batch.height[0]
width = batch.width[0]
self.vae = self.vae.to(get_local_torch_device())
latent_height = height // self.vae.spatial_compression_ratio
latent_width = width // self.vae.spatial_compression_ratio
image = batch.pil_image
image = self.preprocess(
image,
vae_scale_factor=self.vae.spatial_compression_ratio,
height=height,
width=width,
).to(get_local_torch_device(), dtype=torch.float32)
# (B, C, H, W) -> (B, C, 1, H, W)
image = image.unsqueeze(2)
video_condition = torch.cat(
[
image,
image.new_zeros(
image.shape[0],
image.shape[1],
num_frames - 1,
image.shape[3],
image.shape[4],
),
],
dim=2,
)
video_condition = video_condition.to(
device=get_local_torch_device(), dtype=torch.float32
)
# Setup VAE precision
vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision]
vae_autocast_enabled = (
vae_dtype != torch.float32
) and not server_args.disable_autocast
# Encode Image
with torch.autocast(
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
):
if server_args.pipeline_config.vae_tiling:
self.vae.enable_tiling()
# if server_args.vae_sp:
# self.vae.enable_parallel()
if not vae_autocast_enabled:
video_condition = video_condition.to(vae_dtype)
encoder_output = self.vae.encode(video_condition)
if server_args.mode == ExecutionMode.PREPROCESS:
latent_condition = encoder_output.mean
else:
generator = batch.generator
if generator is None:
raise ValueError("Generator must be provided")
latent_condition = self.retrieve_latents(encoder_output, generator)
# Apply shifting if needed
if hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None:
if isinstance(self.vae.shift_factor, torch.Tensor):
latent_condition -= self.vae.shift_factor.to(
latent_condition.device, latent_condition.dtype
)
else:
latent_condition -= self.vae.shift_factor
if isinstance(self.vae.scaling_factor, torch.Tensor):
latent_condition = latent_condition * self.vae.scaling_factor.to(
latent_condition.device, latent_condition.dtype
)
else:
latent_condition = latent_condition * self.vae.scaling_factor
if server_args.mode == ExecutionMode.PREPROCESS:
batch.image_latent = latent_condition
else:
if isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig):
batch_size = batch.batch_size
if (
batch_size > latent_condition.shape[0]
and batch_size % latent_condition.shape[0] == 0
):
# expand init_latents for batch_size
additional_image_per_prompt = (
batch_size // latent_condition.shape[0]
)
image_latents = torch.cat(
[latent_condition] * additional_image_per_prompt, dim=0
)
elif (
batch_size > latent_condition.shape[0]
and batch_size % latent_condition.shape[0] != 0
):
raise ValueError(
f"Cannot duplicate `image` of batch size {latent_condition.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([latent_condition], dim=0)
image_latent_height, image_latent_width = image_latents.shape[3:]
num_channels_latents = (
self.server_args.pipeline_config.dit_config.arch_config.in_channels
// 4
)
image_latents = _pack_latents(
image_latents,
batch_size,
num_channels_latents,
image_latent_height,
image_latent_width,
)
else:
mask_lat_size = torch.ones(
1, 1, num_frames, latent_height, latent_width
)
mask_lat_size[:, :, list(range(1, num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask,
repeats=self.vae.temporal_compression_ratio,
dim=2,
)
mask_lat_size = torch.concat(
[first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2
)
mask_lat_size = mask_lat_size.view(
1,
-1,
self.vae.temporal_compression_ratio,
latent_height,
latent_width,
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(latent_condition.device)
image_latents = torch.concat([mask_lat_size, latent_condition], dim=1)
batch.image_latent = image_latents
# Offload models if needed
if hasattr(self, "maybe_free_model_hooks"):
self.maybe_free_model_hooks()
self.vae.to("cpu")
return batch
def retrieve_latents(
self,
encoder_output: torch.Tensor,
generator: torch.Generator | None = None,
sample_mode: str = "sample",
):
if sample_mode == "sample":
return encoder_output.sample(generator)
elif sample_mode == "argmax":
return encoder_output.mode()
else:
raise AttributeError("Could not access latents of provided encoder_output")
def preprocess(
self,
image: torch.Tensor | PIL.Image.Image,
vae_scale_factor: int,
height: int | None = None,
width: int | None = None,
resize_mode: str = "default", # "default", "fill", "crop"
) -> torch.Tensor:
if isinstance(image, PIL.Image.Image):
width, height = (
self.server_args.pipeline_config.vae_config.calculate_dimensions(
image, vae_scale_factor, width, height
)
)
image = resize(image, height, width, resize_mode=resize_mode)
image = pil_to_numpy(image) # to np
image = numpy_to_pt(image) # to pt
do_normalize = True
if image.min() < 0:
do_normalize = False
if do_normalize:
image = normalize(image)
return image
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify encoding stage inputs."""
result = VerificationResult()
result.add_check("generator", batch.generator, V.generator_or_list_generators)
if server_args.mode == ExecutionMode.PREPROCESS:
result.add_check("height", batch.height, V.list_not_empty)
result.add_check("width", batch.width, V.list_not_empty)
result.add_check("num_frames", batch.num_frames, V.list_not_empty)
else:
result.add_check("height", batch.height, V.positive_int)
result.add_check("width", batch.width, V.positive_int)
result.add_check("num_frames", batch.num_frames, V.positive_int)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify encoding stage outputs."""
result = VerificationResult()
# result.add_check(
# "image_latent", batch.image_latent, [V.is_tensor, V.with_dims(5)]
# )
return result
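# A minimal sketch (illustrative) of the first-frame conditioning mask built in
# ImageVAEEncodingStage.forward above for the non-Qwen I2V path: frame 0 is 1,
# later frames are 0, the first frame is repeated temporal_compression_ratio
# times, and the result is folded so its length matches the latent frame count
# before being concatenated with the image latents along the channel dimension.
# Assumes num_frames = k * temporal_compression_ratio + 1 so the fold is exact.
import torch


def first_frame_mask_sketch(
    num_frames: int,
    latent_height: int,
    latent_width: int,
    temporal_compression_ratio: int,
) -> torch.Tensor:
    mask = torch.ones(1, 1, num_frames, latent_height, latent_width)
    mask[:, :, 1:] = 0  # only the first (conditioning) frame is kept
    first = mask[:, :, 0:1].repeat_interleave(temporal_compression_ratio, dim=2)
    mask = torch.cat([first, mask[:, :, 1:]], dim=2)
    # Fold groups of temporal_compression_ratio frames into a new dim that
    # lines up with the latent temporal resolution.
    mask = mask.view(1, -1, temporal_compression_ratio, latent_height, latent_width)
    return mask.transpose(1, 2)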
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Input validation stage for diffusion pipelines.
"""
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from sglang.multimodal_gen.configs.pipelines import WanI2V480PConfig
from sglang.multimodal_gen.configs.pipelines.qwen_image import (
QwenImageEditPipelineConfig,
)
from sglang.multimodal_gen.runtime.models.vision_utils import load_image, load_video
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators,
VerificationResult,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.utils import best_output_size
logger = init_logger(__name__)
# Alias for convenience
V = StageValidators
# TODO: since this might change sampling params after logging, should we do this beforehand?
class InputValidationStage(PipelineStage):
"""
Stage for validating and preparing inputs for diffusion pipelines.
This stage validates that all required inputs are present and properly formatted
before proceeding with the diffusion process.
In this stage, the input image and the output resolution may be resized.
"""
def _generate_seeds(self, batch: Req, server_args: ServerArgs):
"""Generate seeds for the inference"""
seed = batch.seed
num_videos_per_prompt = batch.num_outputs_per_prompt
assert seed is not None
seeds = [seed + i for i in range(num_videos_per_prompt)]
batch.seeds = seeds
# Peiyuan: using a GPU seed causes A100 and H100 to generate different results...
# FIXME: the generators in the latent preparation stage seem to differ from these seeds
batch.generator = [torch.Generator("cpu").manual_seed(seed) for seed in seeds]
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Validate and prepare inputs.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The validated batch information.
"""
self._generate_seeds(batch, server_args)
# Ensure prompt is properly formatted
if batch.prompt is None and batch.prompt_embeds is None:
raise ValueError("Either `prompt` or `prompt_embeds` must be provided")
# Ensure negative prompt is properly formatted if using classifier-free guidance
if (
batch.do_classifier_free_guidance
and batch.negative_prompt is None
and batch.negative_prompt_embeds is None
):
raise ValueError(
"For classifier-free guidance, either `negative_prompt` or "
"`negative_prompt_embeds` must be provided"
)
# Validate height and width
if batch.height is None or batch.width is None:
raise ValueError(
"Height and width must be provided. Please set `height` and `width`."
)
if batch.height % 8 != 0 or batch.width % 8 != 0:
raise ValueError(
f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}."
)
# Validate number of inference steps
if batch.num_inference_steps <= 0:
raise ValueError(
f"Number of inference steps must be positive, but got {batch.num_inference_steps}"
)
# Validate guidance scale if using classifier-free guidance
if batch.do_classifier_free_guidance and batch.guidance_scale <= 0:
raise ValueError(
f"Guidance scale must be positive, but got {batch.guidance_scale}"
)
# for i2v, get image from image_path
# @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage
if batch.image_path is not None:
if batch.image_path.endswith(".mp4"):
image = load_video(batch.image_path)[0]
else:
image = load_image(batch.image_path)
batch.pil_image = image
# NOTE: resizing needs to be done here, in advance
if isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig):
height = None if batch.height_not_provided else batch.height
width = None if batch.width_not_provided else batch.width
width, height = server_args.pipeline_config.set_width_and_height(
height, width, batch.pil_image
)
batch.width = width
batch.height = height
elif (
server_args.pipeline_config.ti2v_task
or server_args.pipeline_config.ti2i_task
) and batch.pil_image is not None:
# further processing for ti2v task
img = batch.pil_image
ih, iw = img.height, img.width
patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size
vae_stride = (
server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
)
dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride
max_area = 704 * 1280
ow, oh = best_output_size(iw, ih, dw, dh, max_area)
scale = max(ow / iw, oh / ih)
img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
logger.info("resized img height: %s, img width: %s", img.height, img.width)
# center-crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
assert img.width == ow and img.height == oh
# to tensor
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
img = img.unsqueeze(0)
batch.height = oh
batch.width = ow
# TODO: should we store in a new field: pixel values?
batch.pil_image = img
if isinstance(server_args.pipeline_config, WanI2V480PConfig):
# TODO: could we merge with above?
# resize image only, Wan2.1 I2V
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = (
server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
* server_args.pipeline_config.dit_config.arch_config.patch_size[1]
)
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
batch.pil_image = batch.pil_image.resize((width, height))
batch.height = height
batch.width = width
return batch
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify input validation stage inputs."""
result = VerificationResult()
result.add_check("seed", batch.seed, [V.not_none, V.non_negative_int])
result.add_check(
"num_videos_per_prompt", batch.num_outputs_per_prompt, V.positive_int
)
result.add_check(
"prompt_or_embeds",
None,
lambda _: V.string_or_list_strings(batch.prompt)
or V.list_not_empty(batch.prompt_embeds),
)
result.add_check("height", batch.height, V.positive_int)
result.add_check("width", batch.width, V.positive_int)
result.add_check(
"num_inference_steps", batch.num_inference_steps, V.positive_int
)
result.add_check(
"guidance_scale",
batch.guidance_scale,
lambda x: not batch.do_classifier_free_guidance or V.positive_float(x),
)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify input validation stage outputs."""
result = VerificationResult()
result.add_check("seeds", batch.seeds, V.list_not_empty)
result.add_check("generator", batch.generator, V.generator_or_list_generators)
return result
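# A minimal sketch (illustrative) of the Wan2.1 I2V resize rule applied in
# InputValidationStage.forward above: choose a height/width that preserves the
# input aspect ratio, keeps height*width near max_area, and rounds each side
# down to a multiple of mod_value (VAE spatial scale factor * DiT patch size).
# The mod_value default of 16 is an assumption used only for illustration.
import numpy as np


def wan_i2v_target_size(
    image_height: int,
    image_width: int,
    max_area: int = 720 * 1280,
    mod_value: int = 16,
) -> tuple[int, int]:
    aspect_ratio = image_height / image_width
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return height, width


# e.g. a 1080x1920 input maps to (720, 1280) with the defaults above.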
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Latent preparation stage for diffusion pipelines.
"""
from diffusers.utils.torch_utils import randn_tensor
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class LatentPreparationStage(PipelineStage):
"""
Stage for preparing initial latent variables for the diffusion process.
This stage handles the preparation of the initial latent variables that will be
denoised during the diffusion process.
"""
def __init__(self, scheduler, transformer) -> None:
super().__init__()
self.scheduler = scheduler
self.transformer = transformer
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Prepare initial latent variables for the diffusion process.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with prepared latent variables.
"""
latent_num_frames = None
# Adjust video length based on VAE version if needed
if hasattr(self, "adjust_video_length"):
latent_num_frames = self.adjust_video_length(batch, server_args)
batch_size = batch.batch_size
# Get required parameters
dtype = batch.prompt_embeds[0].dtype
device = get_local_torch_device()
generator = batch.generator
latents = batch.latents
num_frames = (
latent_num_frames if latent_num_frames is not None else batch.num_frames
)
height = batch.height
width = batch.width
# TODO(will): remove this once we add input/output validation for stages
if height is None or width is None:
raise ValueError("Height and width must be provided")
# Validate generator if it's a list
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# Generate or use provided latents
if latents is None:
shape = server_args.pipeline_config.prepare_latent_shape(
batch, batch_size, num_frames
)
latents = randn_tensor(
shape, generator=generator, device=device, dtype=dtype
)
latents = server_args.pipeline_config.pack_latents(
latents, batch_size, batch
)
else:
latents = latents.to(device)
# Scale the initial noise if needed
if hasattr(self.scheduler, "init_noise_sigma"):
latents = latents * self.scheduler.init_noise_sigma
# Update batch with prepared latents
batch.latents = latents
batch.raw_latent_shape = latents.shape
return batch
def adjust_video_length(self, batch: Req, server_args: ServerArgs) -> int:
"""
Adjust video length based on VAE version.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with adjusted video length.
"""
video_length = batch.num_frames
use_temporal_scaling_frames = (
server_args.pipeline_config.vae_config.use_temporal_scaling_frames
)
if use_temporal_scaling_frames:
temporal_scale_factor = (
server_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
)
latent_num_frames = (video_length - 1) // temporal_scale_factor + 1
else: # stepvideo only
latent_num_frames = video_length // 17 * 3
return int(latent_num_frames)
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify latent preparation stage inputs."""
result = VerificationResult()
result.add_check(
"prompt_or_embeds",
None,
lambda _: V.string_or_list_strings(batch.prompt)
or V.list_not_empty(batch.prompt_embeds),
)
result.add_check("prompt_embeds", batch.prompt_embeds, V.list_of_tensors)
result.add_check(
"num_videos_per_prompt", batch.num_outputs_per_prompt, V.positive_int
)
result.add_check("generator", batch.generator, V.generator_or_list_generators)
result.add_check("num_frames", batch.num_frames, V.positive_int)
result.add_check("height", batch.height, V.positive_int)
result.add_check("width", batch.width, V.positive_int)
result.add_check("latents", batch.latents, V.none_or_tensor)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify latent preparation stage outputs."""
result = VerificationResult()
if batch.debug:
logger.debug(f"{batch.raw_latent_shape=}")
# disable temporarily for image-generation models
# result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple)
return result
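# A minimal sketch (illustrative) of the frame-count adjustment performed by
# LatentPreparationStage.adjust_video_length above: with temporal scaling,
# num_frames pixel frames map to (num_frames - 1) // ratio + 1 latent frames
# (the stepvideo-only branch uses num_frames // 17 * 3 instead).
def latent_num_frames(num_frames: int, temporal_compression_ratio: int = 4) -> int:
    return (num_frames - 1) // temporal_compression_ratio + 1


# e.g. 81 pixel frames with 4x temporal compression become 21 latent frames.
assert latent_num_frames(81, 4) == 21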
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import torch
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
# The dedicated stepvideo prompt encoding stage.
class StepvideoPromptEncodingStage(PipelineStage):
"""
Stage for encoding prompts for stepvideo.
This stage applies the magic string transformations and runs the
step-LLM and CLIP encoders to produce:
- primary prompt embeddings,
- an attention mask,
- and a clip embedding.
"""
def __init__(self, stepllm, clip) -> None:
super().__init__()
# self.caption_client = caption_client # This should have a call_caption(prompts: List[str]) method.
self.stepllm = stepllm
self.clip = clip
@torch.no_grad()
def forward(self, batch: Req, server_args) -> Req:
prompts = [batch.prompt + server_args.pipeline_config.pos_magic]
bs = len(prompts)
prompts += [server_args.pipeline_config.neg_magic] * bs
with set_forward_context(current_timestep=0, attn_metadata=None):
y, y_mask = self.stepllm(prompts)
clip_emb, _ = self.clip(prompts)
len_clip = clip_emb.shape[1]
y_mask = torch.nn.functional.pad(y_mask, (len_clip, 0), value=1)
pos_clip, neg_clip = clip_emb[:bs], clip_emb[bs:]
# split positive vs negative text
batch.prompt_embeds = y[:bs] # [bs, seq_len, dim]
batch.negative_prompt_embeds = y[bs : 2 * bs] # [bs, seq_len, dim]
batch.prompt_attention_mask = y_mask[:bs] # [bs, seq_len]
batch.negative_attention_mask = y_mask[bs : 2 * bs] # [bs, seq_len]
batch.clip_embedding_pos = pos_clip
batch.clip_embedding_neg = neg_clip
return batch
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify stepvideo encoding stage inputs."""
result = VerificationResult()
result.add_check("prompt", batch.prompt, V.string_not_empty)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify stepvideo encoding stage outputs."""
result = VerificationResult()
result.add_check(
"prompt_embeds", batch.prompt_embeds, [V.is_tensor, V.with_dims(3)]
)
result.add_check(
"negative_prompt_embeds",
batch.negative_prompt_embeds,
[V.is_tensor, V.with_dims(3)],
)
result.add_check(
"prompt_attention_mask",
batch.prompt_attention_mask,
[V.is_tensor, V.with_dims(2)],
)
result.add_check(
"negative_attention_mask",
batch.negative_attention_mask,
[V.is_tensor, V.with_dims(2)],
)
result.add_check(
"clip_embedding_pos",
batch.clip_embedding_pos,
[V.is_tensor, V.with_dims(2)],
)
result.add_check(
"clip_embedding_neg",
batch.clip_embedding_neg,
[V.is_tensor, V.with_dims(2)],
)
return result
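# A minimal sketch (illustrative) of the batching trick used by
# StepvideoPromptEncodingStage.forward above: positive prompts (with the
# positive magic suffix) and negative magic prompts are encoded together in a
# single batch of size 2 * bs, the attention mask is left-padded with ones for
# the CLIP tokens, and the outputs are split back into pos/neg halves.
import torch


def split_pos_neg(y: torch.Tensor, y_mask: torch.Tensor, len_clip: int, bs: int):
    # Prepend `len_clip` ones so the mask also covers the clip embedding tokens.
    y_mask = torch.nn.functional.pad(y_mask, (len_clip, 0), value=1)
    return (y[:bs], y[bs : 2 * bs]), (y_mask[:bs], y_mask[bs : 2 * bs])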
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Prompt encoding stages for diffusion pipelines.
This module contains implementations of prompt encoding stages for diffusion pipelines.
"""
import torch
from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput
from sglang.multimodal_gen.configs.pipelines import FluxPipelineConfig
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class TextEncodingStage(PipelineStage):
"""
Stage for encoding text prompts into embeddings for diffusion models.
This stage handles the encoding of text prompts into the embedding space
expected by the diffusion model.
"""
def __init__(self, text_encoders, tokenizers) -> None:
"""
Initialize the prompt encoding stage.
"""
super().__init__()
self.tokenizers = tokenizers
self.text_encoders = text_encoders
@torch.no_grad()
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Encode the prompt into text encoder hidden states.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with encoded prompt embeddings.
"""
assert len(self.tokenizers) == len(self.text_encoders)
assert len(self.text_encoders) == len(
server_args.pipeline_config.text_encoder_configs
)
# Encode positive prompt with all available encoders
assert batch.prompt is not None
prompt_text: str | list[str] = batch.prompt
all_indices: list[int] = list(range(len(self.text_encoders)))
prompt_embeds_list, prompt_masks_list, pooler_embeds_list = self.encode_text(
prompt_text,
server_args,
encoder_index=all_indices,
return_attention_mask=True,
)
for pe in prompt_embeds_list:
batch.prompt_embeds.append(pe)
for pe in pooler_embeds_list:
batch.pooled_embeds.append(pe)
if batch.prompt_attention_mask is not None:
for am in prompt_masks_list:
batch.prompt_attention_mask.append(am)
# Encode negative prompt if CFG is enabled
if batch.do_classifier_free_guidance:
assert isinstance(batch.negative_prompt, str)
neg_embeds_list, neg_masks_list, neg_pooler_embeds_list = self.encode_text(
batch.negative_prompt,
server_args,
encoder_index=all_indices,
return_attention_mask=True,
)
assert batch.negative_prompt_embeds is not None
for ne in neg_embeds_list:
batch.negative_prompt_embeds.append(ne)
for pe in neg_pooler_embeds_list:
batch.neg_pooled_embeds.append(pe)
if batch.negative_attention_mask is not None:
for nm in neg_masks_list:
batch.negative_attention_mask.append(nm)
return batch
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify text encoding stage inputs."""
result = VerificationResult()
result.add_check("prompt", batch.prompt, V.string_or_list_strings)
result.add_check(
"negative_prompt",
batch.negative_prompt,
lambda x: not batch.do_classifier_free_guidance or V.string_not_none(x),
)
result.add_check(
"do_classifier_free_guidance",
batch.do_classifier_free_guidance,
V.bool_value,
)
result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list)
result.add_check(
"negative_prompt_embeds", batch.negative_prompt_embeds, V.none_or_list
)
return result
def prepare_tokenizer_kwargs(self, tokenizer_kwargs, **kwargs):
tok_kwargs = tokenizer_kwargs | kwargs
return tok_kwargs
@torch.no_grad()
def encode_text(
self,
text: str | list[str],
server_args: ServerArgs,
encoder_index: int | list[int] | None = None,
return_attention_mask: bool = False,
return_type: str = "list", # one of: "list", "dict", "stack"
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
max_length: int | None = None,
truncation: bool | None = None,
padding: bool | str | None = None,
return_overflowing_tokens=None,
return_length=None,
):
"""
Encode plain text using selected text encoder(s) and return embeddings.
Args:
text: A single string or a list of strings to encode.
server_args: The inference arguments providing pipeline config,
including tokenizer and encoder settings, preprocess and postprocess
functions.
encoder_index: Encoder selector by index. Accepts an int or list of ints.
return_attention_mask: If True, also return attention masks for each
selected encoder.
return_type: "list" (default) returns a list aligned with selection;
"dict" returns a dict keyed by encoder index as a string; "stack" stacks along a
new first dimension (requires matching shapes).
device: Optional device override for inputs; defaults to local torch device.
dtype: Optional dtype to cast returned embeddings to.
max_length: Optional per-call tokenizer override.
truncation: Optional per-call tokenizer override.
padding: Optional per-call tokenizer override.
Returns:
Depending on return_type and return_attention_mask:
- list: List[Tensor] or (List[Tensor], List[Tensor])
- dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor])
- stack: Tensor of shape [num_encoders, ...] or a tuple with stacked
attention masks
"""
assert len(self.tokenizers) == len(self.text_encoders)
assert len(self.text_encoders) == len(
server_args.pipeline_config.text_encoder_configs
)
# Resolve selection into indices
encoder_cfgs = server_args.pipeline_config.text_encoder_configs
if encoder_index is None:
indices: list[int] = [0]
elif isinstance(encoder_index, int):
indices = [encoder_index]
else:
indices = list(encoder_index)
# validate range
num_encoders = len(self.text_encoders)
for idx in indices:
if idx < 0 or idx >= num_encoders:
raise IndexError(
f"encoder index {idx} out of range [0, {num_encoders - 1}]"
)
# Normalize input to list[str]
assert isinstance(text, str | list)
if isinstance(text, str):
texts: list[str] = [text]
else:
texts = text
embeds_list: list[torch.Tensor] = []
pooled_embeds_list: list[torch.Tensor] = []
attn_masks_list: list[torch.Tensor] = []
preprocess_funcs = server_args.pipeline_config.preprocess_text_funcs
postprocess_funcs = server_args.pipeline_config.postprocess_text_funcs
text_encoder_extra_args = server_args.pipeline_config.text_encoder_extra_args
encoder_cfgs = server_args.pipeline_config.text_encoder_configs
if return_type not in ("list", "dict", "stack"):
raise ValueError(
f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'"
)
target_device = device if device is not None else get_local_torch_device()
for i in indices:
tokenizer = self.tokenizers[i]
text_encoder = self.text_encoders[i]
encoder_config = encoder_cfgs[i]
preprocess_func = preprocess_funcs[i]
postprocess_func = postprocess_funcs[i]
text_encoder_extra_arg = (
text_encoder_extra_args[i]
if i < len(text_encoder_extra_args) and text_encoder_extra_args[i]
else {}
)
processed_texts: list[str] = []
for prompt_str in texts:
processed_texts.append(preprocess_func(prompt_str))
# Prepare tokenizer args
tok_kwargs = self.prepare_tokenizer_kwargs(
encoder_config.tokenizer_kwargs,
**text_encoder_extra_arg,
)
text_inputs = tokenizer(processed_texts, **tok_kwargs).to(target_device)
input_ids = text_inputs["input_ids"]
is_flux = isinstance(server_args.pipeline_config, FluxPipelineConfig)
is_flux_t5 = is_flux and i == 1
if is_flux_t5:
attention_mask = torch.ones(input_ids.shape[:2], device=target_device)
else:
attention_mask = text_inputs["attention_mask"]
with set_forward_context(current_timestep=0, attn_metadata=None):
outputs: BaseEncoderOutput = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
prompt_embeds = postprocess_func(outputs, text_inputs)
if dtype is not None:
prompt_embeds = prompt_embeds.to(dtype=dtype)
embeds_list.append(prompt_embeds)
if is_flux:
pooled_embeds_list.append(outputs.pooler_output)
if return_attention_mask:
attn_masks_list.append(attention_mask)
# Shape results according to return_type
if return_type == "list":
if return_attention_mask:
return embeds_list, attn_masks_list, pooled_embeds_list
return embeds_list, pooled_embeds_list
if return_type == "dict":
key_strs = [str(i) for i in indices]
embeds_dict = {k: v for k, v in zip(key_strs, embeds_list, strict=False)}
if return_attention_mask:
attn_dict = {
k: v for k, v in zip(key_strs, attn_masks_list, strict=False)
}
return embeds_dict, attn_dict
return embeds_dict
# return_type == "stack"
# Validate shapes are compatible
base_shape = list(embeds_list[0].shape)
for t in embeds_list[1:]:
if list(t.shape) != base_shape:
raise ValueError(
f"Cannot stack embeddings with differing shapes: {[list(t.shape) for t in embeds_list]}"
)
stacked_embeds = torch.stack(embeds_list, dim=0)
if return_attention_mask:
base_mask_shape = list(attn_masks_list[0].shape)
for m in attn_masks_list[1:]:
if list(m.shape) != base_mask_shape:
raise ValueError(
f"Cannot stack attention masks with differing shapes: {[list(m.shape) for m in attn_masks_list]}"
)
stacked_masks = torch.stack(attn_masks_list, dim=0)
return stacked_embeds, stacked_masks
return stacked_embeds
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify text encoding stage outputs."""
result = VerificationResult()
result.add_check(
"prompt_embeds", batch.prompt_embeds, V.list_of_tensors_min_dims(2)
)
result.add_check(
"negative_prompt_embeds",
batch.negative_prompt_embeds,
lambda x: not batch.do_classifier_free_guidance
or V.list_of_tensors_with_min_dims(x, 2),
)
if batch.debug:
logger.debug(f"{batch.prompt_embeds=}")
logger.debug(f"{batch.negative_prompt_embeds=}")
return result
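# A minimal sketch (illustrative) of the three return_type conventions handled
# by TextEncodingStage.encode_text above: "list" keeps per-encoder order,
# "dict" keys results by the encoder index as a string, and "stack" requires
# matching shapes and stacks along a new leading dimension. The helper name is
# an assumption, not part of the pipeline API.
import torch


def shape_encoder_outputs(
    embeds: list[torch.Tensor], indices: list[int], return_type: str = "list"
):
    if return_type == "list":
        return embeds
    if return_type == "dict":
        return {str(i): e for i, e in zip(indices, embeds, strict=False)}
    if return_type == "stack":
        base_shape = list(embeds[0].shape)
        if any(list(e.shape) != base_shape for e in embeds[1:]):
            raise ValueError("Cannot stack embeddings with differing shapes")
        return torch.stack(embeds, dim=0)
    raise ValueError(f"Invalid return_type {return_type!r}")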
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Timestep preparation stages for diffusion pipelines.
This module contains implementations of timestep preparation stages for diffusion pipelines.
"""
import inspect
from typing import Any, Callable, Tuple
import numpy as np
from sglang.multimodal_gen.configs.pipelines import FluxPipelineConfig
from sglang.multimodal_gen.configs.pipelines.qwen_image import (
QwenImageEditPipelineConfig,
QwenImagePipelineConfig,
)
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.base import (
PipelineStage,
StageParallelismType,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import (
StageValidators as V,
)
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class TimestepPreparationStage(PipelineStage):
"""
Stage for preparing timesteps for the diffusion process.
This stage handles the preparation of the timestep sequence that will be used
during the diffusion process.
"""
def __init__(
self,
scheduler,
prepare_extra_set_timesteps_kwargs: list[
Callable[[Req, ServerArgs], Tuple[str, Any]]
] = [],
) -> None:
self.scheduler = scheduler
self.prepare_extra_set_timesteps_kwargs = prepare_extra_set_timesteps_kwargs
@property
def parallelism_type(self) -> StageParallelismType:
return StageParallelismType.REPLICATED
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Prepare timesteps for the diffusion process.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The batch with prepared timesteps.
"""
scheduler = self.scheduler
device = get_local_torch_device()
num_inference_steps = batch.num_inference_steps
timesteps = batch.timesteps
sigmas = batch.sigmas
n_tokens = batch.n_tokens
is_flux = (
isinstance(server_args.pipeline_config, FluxPipelineConfig)
or isinstance(server_args.pipeline_config, QwenImagePipelineConfig)
or isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig)
)
if is_flux:
sigmas = (
np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
if sigmas is None
else sigmas
)
# Prepare extra kwargs for set_timesteps
extra_set_timesteps_kwargs = {}
if (
n_tokens is not None
and "n_tokens" in inspect.signature(scheduler.set_timesteps).parameters
):
extra_set_timesteps_kwargs["n_tokens"] = n_tokens
for callee in self.prepare_extra_set_timesteps_kwargs:
key, value = callee(batch, server_args)
assert isinstance(key, str)
extra_set_timesteps_kwargs[key] = value
# Handle custom timesteps or sigmas
if timesteps is not None and sigmas is not None:
raise ValueError(
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
)
if timesteps is not None:
accepts_timesteps = (
"timesteps" in inspect.signature(scheduler.set_timesteps).parameters
)
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(
timesteps=timesteps, device=device, **extra_set_timesteps_kwargs
)
timesteps = scheduler.timesteps
elif sigmas is not None:
accept_sigmas = (
"sigmas" in inspect.signature(scheduler.set_timesteps).parameters
)
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(
sigmas=sigmas, device=device, **extra_set_timesteps_kwargs
)
timesteps = scheduler.timesteps
else:
scheduler.set_timesteps(
num_inference_steps, device=device, **extra_set_timesteps_kwargs
)
timesteps = scheduler.timesteps
# Update batch with prepared timesteps
batch.timesteps = timesteps
self.log_debug(f"timesteps: {timesteps}")
return batch
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify timestep preparation stage inputs."""
result = VerificationResult()
result.add_check(
"num_inference_steps", batch.num_inference_steps, V.positive_int
)
result.add_check("timesteps", batch.timesteps, V.none_or_tensor)
result.add_check("sigmas", batch.sigmas, V.none_or_list)
result.add_check("n_tokens", batch.n_tokens, V.none_or_positive_int)
return result
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""Verify timestep preparation stage outputs."""
result = VerificationResult()
result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.with_dims(1)])
return result
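# A minimal sketch (illustrative) of the scheduler plumbing used by
# TimestepPreparationStage.forward above: the flux-style default sigma schedule
# (a linspace from 1.0 down to 1/num_inference_steps) and the inspect-based
# check for whether a scheduler's set_timesteps accepts a given keyword.
import inspect

import numpy as np


def default_flux_sigmas(num_inference_steps: int) -> np.ndarray:
    return np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)


def set_timesteps_accepts(scheduler, kwarg: str) -> bool:
    return kwarg in inspect.signature(scheduler.set_timesteps).parameters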