Unverified Commit 7bc1dae0 authored by Mick, committed by GitHub

WIP: initial multimodal-gen support (#12484)


Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: JiLi <leege233@gmail.com>
Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com>
Co-authored-by: laixin <xielx@shanghaitech.edu.cn>
Co-authored-by: SolitaryThinker <wlsaidhi@gmail.com>
Co-authored-by: jzhang38 <a1286225768@gmail.com>
Co-authored-by: BrianChen1129 <yongqichcd@gmail.com>
Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: rlsu9 <r3su@ucsd.edu>
Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com>
Co-authored-by: foreverpiano <pianoqwz@qq.com>
Co-authored-by: RandNMR73 <notomatthew31@gmail.com>
Co-authored-by: PorridgeSwim <yz3883@columbia.edu>
Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com>
parent 4fe53e58
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
# ==============================================================================
#
# Modified from diffusers==0.35.0.dev0
#
# ==============================================================================
import math
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler
if is_scipy_available():
import scipy.stats
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
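# Illustrative usage sketch (doctest-style): with the default "cosine" transform this
# produces `num_diffusion_timesteps` betas, each capped at `max_beta`.
#
#   >>> betas = betas_for_alpha_bar(1000)
#   >>> betas.shape, bool(betas.max() <= 0.999)
#   (torch.Size([1000]), True)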
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
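# Illustrative sanity check (doctest-style): after rescaling, the terminal cumulative
# alpha, and hence the terminal SNR, is exactly zero.
#
#   >>> betas = torch.linspace(0.0001, 0.02, 1000)
#   >>> rescaled = rescale_zero_terminal_snr(betas)
#   >>> torch.cumprod(1.0 - rescaled, dim=0)[-1].item()
#   0.0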
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
"""
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, default `2`):
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, defaults to `True`):
Whether to use the updating algorithm on the predicted x0.
solver_type (`str`, default `bh2`):
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
disable_corrector (`list`, default `[]`):
Decides at which steps to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
and `epsilon_theta(x_t^c, c)`, which can influence convergence for a large guidance scale. The corrector is
usually disabled during the first few steps.
solver_p (`SchedulerMixin`, default `None`):
Another scheduler that, if specified, turns the algorithm into `solver_p + UniC`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: np.ndarray | list[float] | None = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
solver_type: str = "bh2",
lower_order_final: bool = True,
disable_corrector: list[int] = [],
solver_p: SchedulerMixin = None,
use_karras_sigmas: bool | None = False,
use_exponential_sigmas: bool | None = False,
use_beta_sigmas: bool | None = False,
use_flow_sigmas: bool | None = False,
flow_shift: float | None = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: str | None = "zero", # "zero", "sigma_min"
rescale_betas_zero_snr: bool = False,
use_dynamic_shifting: bool = False,
time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError(
"Make sure to install scipy if you want to use beta sigmas."
)
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(
beta_start, beta_end, num_train_timesteps, dtype=torch.float32
)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(
f"{beta_schedule} is not implemented for {self.__class__}"
)
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
if rescale_betas_zero_snr:
# Close to 0 without being 0 so first sigma is not inf
# FP16 smallest positive subnormal works well here
self.alphas_cumprod[-1] = 2**-24
# Currently we only support VP-type noise schedule
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]:
self.register_to_config(solver_type="bh2")
else:
raise NotImplementedError(
f"{solver_type} is not implemented for {self.__class__}"
)
self.predict_x0 = predict_x0
# setable values
self.num_inference_steps = None
timesteps = np.linspace(
0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32
)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.num_train_timesteps = num_train_timesteps
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
BaseScheduler.__init__(self)
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
"""
return self._begin_index
def set_shift(self, shift: float) -> None:
self.config.flow_shift = shift
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: int,
device: str | torch.device = None,
mu: float | None = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
"""
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if mu is not None:
assert (
self.config.use_dynamic_shifting
and self.config.time_shift_type == "exponential"
)
self.config.flow_shift = np.exp(mu)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(
0, self.config.num_train_timesteps - 1, num_inference_steps + 1
)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(0, num_inference_steps + 1) * step_ratio)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
np.arange(self.config.num_train_timesteps, 0, -step_ratio)
.round()
.copy()
.astype(np.int64)
)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(
in_sigmas=sigmas, num_inference_steps=num_inference_steps
)
timesteps = np.array(
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
).round()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_exponential(
in_sigmas=sigmas, num_inference_steps=num_inference_steps
)
timesteps = np.array(
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_beta_sigmas:
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_beta(
in_sigmas=sigmas, num_inference_steps=num_inference_steps
)
timesteps = np.array(
[self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]
)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_flow_sigmas:
alphas = np.linspace(
1, 1 / self.config.num_train_timesteps, num_inference_steps + 1
)
sigmas = 1.0 - alphas
sigmas = np.flip(
self.config.flow_shift
* sigmas
/ (1 + (self.config.flow_shift - 1) * sigmas)
)[:-1].copy()
timesteps = (sigmas * self.config.num_train_timesteps).copy()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = (
(1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]
) ** 0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(
device=device, dtype=torch.int64
)
self.num_inference_steps = len(timesteps)
self.model_outputs = [
None,
] * self.config.solver_order
self.lower_order_nums = 0
self.last_sample = None
if self.solver_p:
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
# add an index counter for schedulers that allow duplicated timesteps
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
https://huggingface.co/papers/2205.11487
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = (
sample.float()
) # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = (
torch.clamp(sample, -s, s) / s
) # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = (
np.cumsum((dists >= 0), axis=0)
.argmax(axis=0)
.clip(max=log_sigmas.shape[0] - 2)
)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
else:
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(
self, in_sigmas: torch.Tensor, num_inference_steps
) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(
self, in_sigmas: torch.Tensor, num_inference_steps: int
) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.exp(
np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)
)
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self,
in_sigmas: torch.Tensor,
num_inference_steps: int,
alpha: float = 0.6,
beta: float = 0.6,
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""
Convert the model output to the corresponding type the UniPC algorithm needs.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config.prediction_type == "epsilon":
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
x0_pred = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
"`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config.prediction_type == "epsilon":
return model_output
elif self.config.prediction_type == "sample":
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the UniPCMultistepScheduler."
)
def multistep_uni_p_bh_update(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
order: int = None,
**kwargs,
) -> torch.Tensor:
"""
One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model at the current timestep.
prev_timestep (`int`):
The previous discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
order (`int`):
The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
Returns:
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyword argument")
if order is None:
if len(args) > 2:
order = args[2]
else:
raise ValueError("missing `order` as a required keyword argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
"1.0.0",
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
s0 = self.timestep_list[-1]
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
sigma_t, sigma_s0 = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - i
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
else:
D1s = None
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
x_t = x_t.to(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: torch.Tensor,
*args,
last_sample: torch.Tensor = None,
this_sample: torch.Tensor = None,
order: int = None,
**kwargs,
) -> torch.Tensor:
"""
One step for the UniC (B(h) version).
Args:
this_model_output (`torch.Tensor`):
The model outputs at `x_t`.
this_timestep (`int`):
The current timestep `t`.
last_sample (`torch.Tensor`):
The generated sample before the last predictor `x_{t-1}`.
this_sample (`torch.Tensor`):
The generated sample after the last predictor `x_{t}`.
order (`int`):
The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
Returns:
`torch.Tensor`:
The corrected sample tensor at the current timestep.
"""
this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
if last_sample is None:
if len(args) > 1:
last_sample = args[1]
else:
raise ValueError("missing `last_sample` as a required keyword argument")
if this_sample is None:
if len(args) > 2:
this_sample = args[2]
else:
raise ValueError("missing `this_sample` as a required keyword argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
raise ValueError("missing `order` as a required keyword argument")
if this_timestep is not None:
deprecate(
"this_timestep",
"1.0.0",
"Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
model_output_list = self.model_outputs
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
sigma_t, sigma_s0 = (
self.sigmas[self.step_index],
self.sigmas[self.step_index - 1],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
h = lambda_t - lambda_s0
device = this_sample.device
rks = []
D1s = []
for i in range(1, order):
si = self.step_index - (i + 1)
mi = model_output_list[-(i + 1)]
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
x_t = x_t.to(x.dtype)
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
"""
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.Tensor,
timestep: int | torch.Tensor,
sample: torch.Tensor,
return_dict: bool = True,
) -> SchedulerOutput | tuple:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
the multistep UniPC.
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
use_corrector = (
self.step_index > 0
and self.step_index - 1 not in self.disable_corrector
and self.last_sample is not None
)
model_output_convert = self.convert_model_output(model_output, sample=sample)
if use_corrector:
sample = self.multistep_uni_c_bh_update(
this_model_output=model_output_convert,
last_sample=self.last_sample,
this_sample=sample,
order=self.this_order,
)
for i in range(self.config.solver_order - 1):
self.model_outputs[i] = self.model_outputs[i + 1]
self.timestep_list[i] = self.timestep_list[i + 1]
self.model_outputs[-1] = model_output_convert
self.timestep_list[-1] = timestep
if self.config.lower_order_final:
this_order = min(
self.config.solver_order, len(self.timesteps) - self.step_index
)
else:
this_order = self.config.solver_order
self.this_order = min(
this_order, self.lower_order_nums + 1
) # warmup for multistep
assert self.this_order > 0
self.last_sample = sample
prev_sample = self.multistep_uni_p_bh_update(
model_output=model_output, # pass the original non-converted model output, in case solver-p is used
sample=sample,
order=self.this_order,
)
if self.lower_order_nums < self.config.solver_order:
self.lower_order_nums += 1
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(
device=original_samples.device, dtype=original_samples.dtype
)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(
original_samples.device, dtype=torch.float32
)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [
self.index_for_timestep(t, schedule_timesteps) for t in timesteps
]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
else:
# add_noise is called before the first denoising step to create the initial latent (img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
EntryClass = UniPCMultistepScheduler
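# Illustrative smoke test (runs only when this module is executed directly): random noise
# stands in for a real diffusion model's epsilon prediction, so the output is meaningless,
# but it exercises set_timesteps() and the predictor/corrector step() loop end to end.
if __name__ == "__main__":
    scheduler = UniPCMultistepScheduler(solver_order=2, prediction_type="epsilon")
    scheduler.set_timesteps(num_inference_steps=20)
    sample = torch.randn(1, 4, 8, 8)  # dummy latent batch
    for t in scheduler.timesteps:
        noise_pred = torch.randn_like(sample)  # stand-in for model(sample, t)
        sample = scheduler.step(noise_pred, t, sample).prev_sample
    print(sample.shape)  # torch.Size([1, 4, 8, 8])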
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/utils.py
"""Utils for model executor."""
from typing import Any
import torch
# TODO(PY): move it elsewhere
def auto_attributes(init_func):
"""
Decorator that automatically adds all initialization arguments as object attributes.
Example:
@auto_attributes
def __init__(self, a=1, b=2):
pass
# This will automatically set:
# - self.a = 1 and self.b = 2
# - self.config.a = 1 and self.config.b = 2
"""
def wrapper(self, *args, **kwargs):
# Get the function signature
import inspect
signature = inspect.signature(init_func)
parameters = signature.parameters
# Get parameter names (excluding 'self')
param_names = list(parameters.keys())[1:]
# Bind arguments to parameters
bound_args = signature.bind(self, *args, **kwargs)
bound_args.apply_defaults()
# Create config object if it doesn't exist
if not hasattr(self, "config"):
self.config = type("Config", (), {})()
# Set attributes on self and self.config
for name in param_names:
if name in bound_args.arguments:
value = bound_args.arguments[name]
setattr(self, name, value)
setattr(self.config, name, value)
# Call the original __init__ function
return init_func(self, *args, **kwargs)
return wrapper
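# Illustrative usage sketch for `auto_attributes` (doctest-style; `_Demo` is a hypothetical
# class used only for this example):
#
#   >>> class _Demo:
#   ...     @auto_attributes
#   ...     def __init__(self, hidden_size=16, num_layers=2):
#   ...         pass
#   >>> demo = _Demo(num_layers=4)
#   >>> demo.hidden_size, demo.config.num_layers
#   (16, 4)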
def set_weight_attrs(
weight: torch.Tensor,
weight_attrs: dict[str, Any] | None,
):
"""Set attributes on a weight tensor.
This method is used to set attributes on a weight tensor. This method
will not overwrite existing attributes.
Args:
weight: The weight tensor.
weight_attrs: A dictionary of attributes to set on the weight tensor.
"""
if weight_attrs is None:
return
for key, value in weight_attrs.items():
assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
# NOTE(woosuk): During weight loading, we often do something like:
# narrowed_tensor = param.data.narrow(0, offset, len)
# narrowed_tensor.copy_(real_weight)
# expecting narrowed_tensor and param.data to share the same storage.
# However, on TPUs, narrowed_tensor will lazily propagate to the base
# tensor, which is param.data, leading to the redundant memory usage.
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from sglang.multimodal_gen.runtime.platforms import current_platform
if current_platform.is_tpu() and key == "weight_loader":
value = _make_synced_weight_loader(value)
setattr(weight, key, value)
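# Illustrative usage sketch (doctest-style): tag a parameter with loader metadata;
# "output_dim" is just an arbitrary example attribute here, not a required key.
#
#   >>> w = torch.nn.Parameter(torch.empty(8, 4))
#   >>> set_weight_attrs(w, {"output_dim": 0})
#   >>> w.output_dim
#   0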
def _make_synced_weight_loader(original_weight_loader) -> Any:
def _synced_weight_loader(param, *args, **kwargs):
original_weight_loader(param, *args, **kwargs)
torch._sync(param)
return _synced_weight_loader
def extract_layer_index(layer_name: str) -> int:
"""
Extract the layer index from the module name.
Examples:
- "encoder.layers.0" -> 0
- "encoder.layers.1.self_attn" -> 1
- "2.self_attn" -> 2
- "model.encoder.layers.0.sub.1" -> ValueError
"""
subnames = layer_name.split(".")
int_vals: list[int] = []
for subname in subnames:
try:
int_vals.append(int(subname))
except ValueError:
continue
assert len(int_vals) == 1, (
f"layer name {layer_name} should" " only contain one integer"
)
return int_vals[0]
def modulate(
x: torch.Tensor,
shift: torch.Tensor | None = None,
scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""modulate by shift and scale
Args:
x (torch.Tensor): input tensor.
shift (torch.Tensor, optional): shift tensor. Defaults to None.
scale (torch.Tensor, optional): scale tensor. Defaults to None.
Returns:
torch.Tensor: the output tensor after modulate.
"""
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1)) # type: ignore[union-attr]
elif scale is None:
return x + shift.unsqueeze(1) # type: ignore[union-attr]
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(
1
) # type: ignore[union-attr]
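# Illustrative shape sketch for `modulate` (doctest-style): `shift` and `scale` are
# per-sample vectors that get broadcast over the sequence dimension.
#
#   >>> x = torch.randn(2, 10, 64)   # [batch, seq_len, dim]
#   >>> shift = torch.randn(2, 64)   # [batch, dim]
#   >>> scale = torch.randn(2, 64)
#   >>> modulate(x, shift=shift, scale=scale).shape
#   torch.Size([2, 10, 64])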
def pred_noise_to_pred_video(
pred_noise: torch.Tensor,
noise_input_latent: torch.Tensor,
timestep: torch.Tensor,
scheduler: Any,
) -> torch.Tensor:
"""
Convert predicted noise to clean latent.
Args:
pred_noise: the predicted noise with shape [B, C, H, W]
where B is batch_size or batch_size * num_frames
noise_input_latent: the noisy latent with shape [B, C, H, W],
timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames]
scheduler: the scheduler
Returns:
the predicted video with shape [B, C, H, W]
"""
# If timestep is [bs, num_frames]
if timestep.ndim == 2:
timestep = timestep.flatten(0, 1)
assert timestep.numel() == noise_input_latent.shape[0]
elif timestep.ndim == 1:
# If timestep is [1]
if timestep.shape[0] == 1:
timestep = timestep.expand(noise_input_latent.shape[0])
else:
assert timestep.numel() == noise_input_latent.shape[0]
else:
raise ValueError(
f"[pred_noise_to_pred_video] Invalid timestep shape: {timestep.shape}"
)
# timestep shape should be [B]
dtype = pred_noise.dtype
device = pred_noise.device
pred_noise = pred_noise.double().to(device)
noise_input_latent = noise_input_latent.double().to(device)
sigmas = scheduler.sigmas.double().to(device)
timesteps = scheduler.timesteps.double().to(device)
timestep_id = torch.argmin(
(timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1
)
sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
pred_video = noise_input_latent - sigma_t * pred_noise
return pred_video.to(dtype)
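# Illustrative conversion sketch (doctest-style), assuming `scheduler` is any flow-matching
# style scheduler whose float `sigmas` and `timesteps` have already been populated via
# `set_timesteps`:
#
#   >>> latents = torch.randn(2, 16, 32, 32)
#   >>> noise_pred = torch.randn_like(latents)
#   >>> t = scheduler.timesteps[:1]  # shape [1]; broadcast to the whole batch
#   >>> video = pred_noise_to_pred_video(noise_pred, latents, t, scheduler)
#   >>> video.shape
#   torch.Size([2, 16, 32, 32])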
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from typing import Dict, Optional, Tuple, Union
import torch
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from diffusers.models.autoencoders.vae import (
Decoder,
DecoderOutput,
DiagonalGaussianDistribution,
Encoder,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from torch import nn
from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig
class AutoencoderKL(nn.Module):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
force_upcast (`bool`, *optional*, defaults to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
mid_block_add_attention (`bool`, *optional*, defaults to `True`):
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
mid_block will only have resnet blocks
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
def __init__(
self,
config: FluxVAEConfig,
):
super().__init__()
self.config = config
arch_config = config.arch_config
in_channels = arch_config.in_channels
out_channels = arch_config.out_channels
down_block_types = arch_config.down_block_types
up_block_types = arch_config.up_block_types
block_out_channels = arch_config.block_out_channels
layers_per_block = arch_config.layers_per_block
act_fn = arch_config.act_fn
latent_channels = arch_config.latent_channels
norm_num_groups = arch_config.norm_num_groups
sample_size = arch_config.sample_size
use_quant_conv = arch_config.use_quant_conv
use_post_quant_conv = arch_config.use_post_quant_conv
mid_block_add_attention = arch_config.mid_block_add_attention
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
mid_block_add_attention=mid_block_add_attention,
)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = (
nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
if use_quant_conv
else None
)
self.post_quant_conv = (
nn.Conv2d(latent_channels, latent_channels, 1)
if use_post_quant_conv
else None
)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(
sample_size / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_overlap_factor = 0.25
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnAddedKVProcessor()
elif all(
proc.__class__ in CROSS_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (
width > self.tile_sample_min_size or height > self.tile_sample_min_size
):
return self._tiled_encode(x)
enc = self.encoder(x)
if self.quant_conv is not None:
enc = self.quant_conv(enc)
return enc
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(
self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (
z.shape[-1] > self.tile_latent_min_size
or z.shape[-2] > self.tile_latent_min_size
):
return self.tiled_decode(z, return_dict=return_dict)
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def decode(self, z: torch.FloatTensor) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
Returns:
`torch.Tensor`:
The decoded images; this adapted `decode` takes no `return_dict` argument and always
returns the decoded tensor directly rather than a [`~models.vae.DecoderOutput`].
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
return decoded
def blend_v(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[
:, :, y, :
] * (y / blend_extent)
return b
def blend_h(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[
:, :, :, x
] * (x / blend_extent)
return b
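    # Illustrative sketch (not executed): blend_h cross-fades the overlapping columns of two
    # horizontally adjacent tiles with linear weights. Assuming a blend_extent of 4:
    #   result column x = left_tile[..., -4 + x] * (1 - x / 4) + right_tile[..., x] * (x / 4)
    # so x = 0 is taken entirely from the left tile and the weight shifts linearly toward the
    # right tile as x approaches blend_extent, which hides the seam between tiles. blend_v does
    # the same thing along the height axis.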
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
        # Split the image into overlapping tiles of size `tile_sample_min_size` and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile)
if self.config.use_quant_conv:
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
enc = torch.cat(result_rows, dim=2)
return enc
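    # Illustrative sketch (not executed), assuming the common defaults tile_sample_min_size=512,
    # tile_latent_min_size=64, tile_overlap_factor=0.25 (values not fixed by this file):
    #   overlap_size = int(512 * (1 - 0.25))  # = 384, stride between tile origins in pixel space
    #   blend_extent = int(64 * 0.25)         # = 16, latent rows/cols cross-faded between tiles
    #   row_limit    = 64 - 16                # = 48, latent rows/cols kept from each blended tile
    # Under those defaults a 1024x1024 input is covered by a 3x3 grid of overlapping tiles whose
    # blended, cropped latents concatenate back into a single 128x128 latent grid, matching the
    # 8x spatial compression of the non-tiled path.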
def tiled_encode(
self, x: torch.Tensor, return_dict: bool = True
) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
deprecation_message = (
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
"to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
)
# deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
        # Split the image into overlapping tiles of size `tile_sample_min_size` and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile)
if self.config.use_quant_conv:
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(
self, z: torch.Tensor, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
        # Split z into overlapping tiles of size `tile_latent_min_size` and decode them separately.
        # The tiles overlap to avoid seams between them.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[
:,
:,
i : i + self.tile_latent_min_size,
j : j + self.tile_latent_min_size,
]
if self.config.use_post_quant_conv:
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
            generator (`torch.Generator`, *optional*):
                A generator used when sampling from the posterior.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
return dec
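    # Usage sketch (illustrative, not executed); `vae` and `images` are hypothetical names:
    #   posterior = vae.encode(images).latent_dist   # DiagonalGaussianDistribution
    #   latents = posterior.sample()                 # or posterior.mode() for the deterministic path
    #   reconstruction = vae.decode(latents)         # decode() in this class returns the sample tensor directly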
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
        > [!WARNING]
        > This API is 🧪 experimental.
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError(
"`fuse_qkv_projections()` is not supported for models having added KV projections."
)
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedAttnProcessor2_0())
EntryClass = AutoencoderKL
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.activations import get_activation
from diffusers.models.autoencoders.vae import (
DecoderOutput,
DiagonalGaussianDistribution,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
class QwenImageCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
caching for efficient inference.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
# Set up causal padding
self._padding = (
self.padding[2],
self.padding[2],
self.padding[1],
self.padding[1],
2 * self.padding[0],
0,
)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
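    # Illustrative sketch (not executed): F.pad's padding tuple for a 5D input is ordered
    # (w_left, w_right, h_top, h_bottom, t_front, t_back). For kernel_size=3 with padding=1 the
    # layer stores self._padding = (1, 1, 1, 1, 2, 0), i.e. symmetric spatial padding but all
    # temporal padding placed in front of the clip, which is what makes the convolution causal
    # in time. When a cache of previous frames is passed in, those cached frames are concatenated
    # in front of the input and replace part of that front padding.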
class QwenImageRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
        dim (int): The number of channels (the size of the dimension being normalized).
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(
self,
dim: int,
channel_first: bool = True,
images: bool = True,
bias: bool = False,
) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
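    # Illustrative sketch (not executed): for channel-first input this forward is equivalent to
    #   y = x / ||x||_2(over channels) * sqrt(dim) * gamma + bias
    # F.normalize divides by the L2 norm along the channel axis, the sqrt(dim) scale turns that
    # into an RMS normalization, and gamma/bias are the learned affine parameters.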
class QwenImageUpsample(nn.Upsample):
r"""
Perform upsampling while ensuring the output tensor has the same data type as the input.
Returns:
torch.Tensor: Upsampled tensor with the same data type as the input.
"""
def forward(self, x):
return super().forward(x.float()).type_as(x)
class QwenImageResample(nn.Module):
r"""
A custom resampling module for 2D and 3D data.
Args:
dim (int): The number of input/output channels.
mode (str): The resampling mode. Must be one of:
- 'none': No resampling (identity operation).
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
- 'downsample2d': 2D downsampling with zero-padding and convolution.
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim // 2, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, dim // 2, 3, padding=1),
)
self.time_conv = QwenImageCausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
)
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.time_conv = QwenImageCausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] != "Rep"
):
                        # cache the last frame of each of the previous two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
if (
cache_x.shape[2] < 2
and feat_cache[idx] is not None
and feat_cache[idx] == "Rep"
):
cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2,
)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
)
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
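    # Illustrative shape walk-through (not executed) for mode="upsample3d" with a cached previous
    # chunk: an input of shape (b, c, t, h, w) goes through time_conv, which doubles the channels
    # to 2*c; the reshape/stack above interleaves those two halves along time, giving
    # (b, c, 2*t, h, w). The per-frame resample (nearest-exact 2x upsample plus conv) then maps
    # each frame from (h, w) to (2*h, 2*w) with c // 2 output channels.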
class QwenImageResidualBlock(nn.Module):
r"""
A custom residual block module.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
in_dim: int,
out_dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
) -> None:
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = QwenImageRMS_norm(in_dim, images=False)
self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = QwenImageRMS_norm(out_dim, images=False)
self.dropout = nn.Dropout(dropout)
self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
self.conv_shortcut = (
QwenImageCausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim
else nn.Identity()
)
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Apply shortcut connection
h = self.conv_shortcut(x)
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)
# Dropout
x = self.dropout(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv2(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)
# Add residual connection
return x + h
class QwenImageAttentionBlock(nn.Module):
r"""
    Single-head self-attention applied over the spatial positions of each frame.
Args:
dim (int): The number of channels in the input tensor.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = QwenImageRMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def forward(self, x):
identity = x
batch_size, channels, time, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x)
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
qkv = qkv.permute(0, 1, 3, 2).contiguous()
q, k, v = qkv.chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(q, k, v)
x = (
x.squeeze(1)
.permute(0, 2, 1)
.reshape(batch_size * time, channels, height, width)
)
# output projection
x = self.proj(x)
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
x = x.view(batch_size, time, channels, height, width)
x = x.permute(0, 2, 1, 3, 4)
return x + identity
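    # Illustrative shape walk-through (not executed): an input of shape (b, c, t, h, w) is folded
    # into (b*t, c, h, w), normalized, and projected to q/k/v of shape (b*t, 1, h*w, c) each, so
    # scaled_dot_product_attention runs a single head over the h*w spatial positions of every
    # frame independently; the result is projected back and added to the residual, restoring the
    # original (b, c, t, h, w) layout.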
class QwenImageMidBlock(nn.Module):
"""
Middle block for QwenImageVAE encoder and decoder.
Args:
dim (int): Number of input/output channels.
dropout (float): Dropout rate.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
num_layers: int = 1,
):
super().__init__()
self.dim = dim
# Create the components
resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
attentions = []
for _ in range(num_layers):
attentions.append(QwenImageAttentionBlock(dim))
resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
return x
class QwenImageEncoder3d(nn.Module):
r"""
A 3D encoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_downsample (list of bool): Whether to downsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
):
super().__init__()
# dim = config.arch_config.dim
# z_dim = config.arch_config.z_dim
# dim_mult = config.arch_config.dim_mult
# num_res_blocks = config.arch_config.num_res_blocks
# attn_scales = config.arch_config.attn_scales
# temperal_downsample = config.arch_config.temperal_downsample
# dropout = config.arch_config.dropout
# non_linearity = config.arch_config.non_linearity
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(
QwenImageResidualBlock(in_dim, out_dim, dropout)
)
if scale in attn_scales:
self.down_blocks.append(QwenImageAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = QwenImageMidBlock(
out_dim, dropout, non_linearity, num_layers=1
)
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of each of the previous two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of each of the previous two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class QwenImageUpBlock(nn.Module):
"""
A block that handles upsampling for the QwenImageVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
upsample_mode: Optional[str] = None,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Create layers list
resnets = []
# Add residual blocks and attention if needed
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(
QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)
)
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
self.upsamplers = None
if upsample_mode is not None:
self.upsamplers = nn.ModuleList(
[QwenImageResample(out_dim, mode=upsample_mode)]
)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
class QwenImageDecoder3d(nn.Module):
r"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.mid_block = QwenImageMidBlock(
dims[0], dropout, non_linearity, num_layers=1
)
# upsample blocks
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
in_dim = in_dim // 2
# Determine if we need upsampling
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
# Create and add the upsampling block
up_block = QwenImageUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of each of the previous two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache the last frame of each of the previous two chunks
cache_x = torch.cat(
[
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class AutoencoderKLQwenImage(nn.Module):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    This adaptation subclasses `nn.Module` directly; see the upstream diffusers `ModelMixin` documentation for the
    generic loading and saving utilities that are not reimplemented here.
"""
_supports_gradient_checkpointing = False
# fmt: off
def __init__(
self,
config: QwenImageVAEConfig,
# base_dim: int = 96,
# z_dim: int = 16,
# dim_mult: Tuple[int] = [1, 2, 4, 4],
# num_res_blocks: int = 2,
# attn_scales: List[float] = [],
# temperal_downsample: List[bool] = [False, True, True],
# dropout: float = 0.0,
# latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134,
# -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
# latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526,
# 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
) -> None:
# fmt: on
super().__init__()
base_dim = config.arch_config.base_dim
z_dim = config.arch_config.z_dim
dim_mult = config.arch_config.dim_mult
num_res_blocks = config.arch_config.num_res_blocks
attn_scales = config.arch_config.attn_scales
temperal_downsample = config.arch_config.temperal_downsample
dropout = config.arch_config.dropout
# non_linearity = config.arch_config.non_linearity
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = QwenImageEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
self.decoder = QwenImageDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
)
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
if self.decoder is not None
else 0,
"encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
if self.encoder is not None
else 0,
}
cuda_device = get_local_torch_device()
        # FIXME: dtype is hardcoded to bfloat16
dtype = torch.bfloat16
latent_channels = config.arch_config.z_dim
self.shift_factor = (
torch.tensor(
config.arch_config.latents_mean
)
.view(1, latent_channels, 1, 1, 1)
.to(cuda_device, dtype)
)
latents_std_tensor = torch.tensor(config.arch_config.latents_std, dtype=dtype, device=cuda_device)
self.scaling_factor = (1.0 / latents_std_tensor).view(1, latent_channels, 1, 1, 1)
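        # Illustrative sketch (not executed): shift_factor holds the per-channel latent means and
        # scaling_factor the reciprocal per-channel stds, both broadcastable over
        # (batch, z_dim, frames, height, width). A caller would typically normalize encoder output
        # as `(latents - shift_factor) * scaling_factor` and invert that before decoding; the exact
        # call site lives outside this module and is assumed here.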
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
        tile_sample_stride_height: Optional[int] = None,
        tile_sample_stride_width: Optional[int] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
                artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def clear_cache(self):
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, QwenImageCausalConv3d):
count += 1
return count
self._conv_num = _count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
self.clear_cache()
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1): 1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
self.clear_cache()
return enc
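    # Illustrative sketch (not executed): frames are encoded causally in chunks of one frame
    # followed by groups of four, e.g. for num_frame = 17:
    #   iter_ = 1 + (17 - 1) // 4  # = 5 passes over frames [0:1], [1:5], [5:9], [9:13], [13:17]
    # self._enc_feat_map carries the trailing frames of each chunk into the next pass, so the
    # causal convolutions see a continuous clip while memory stays bounded per chunk.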
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> DiagonalGaussianDistribution:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return the posterior directly instead of a plain tuple.
        Returns:
            The latent distribution of the encoded videos. If `return_dict` is True, a
            `DiagonalGaussianDistribution` is returned, otherwise a plain `tuple` containing it is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return posterior
def _decode(self, z: torch.Tensor, return_dict: bool = True):
_, _, num_frame, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
self.clear_cache()
x = self.post_quant_conv(z)
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
if not return_dict:
return (out,)
return DecoderOutput(sample=out)
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
return decoded
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
self.clear_cache()
time = []
frame_range = 1 + (num_frames - 1) // 4
for k in range(frame_range):
self._enc_conv_idx = [0]
if k == 0:
tile = x[:, :, :1, i: i + self.tile_sample_min_height, j: j + self.tile_sample_min_width]
else:
tile = x[
:,
:,
1 + 4 * (k - 1): 1 + 4 * k,
i: i + self.tile_sample_min_height,
j: j + self.tile_sample_min_width,
]
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
tile = self.quant_conv(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_, _, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
self.clear_cache()
time = []
for k in range(num_frames):
self._conv_idx = [0]
tile = z[:, :, k: k + 1, i: i + tile_latent_min_height, j: j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
self.clear_cache()
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
        # encode() returns the DiagonalGaussianDistribution directly in this adaptation
        posterior = self.encode(x)
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
EntryClass = AutoencoderKLQwenImage
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from collections.abc import Iterator
from math import prod
from typing import Optional, cast
import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils.torch_utils import randn_tensor
from sglang.multimodal_gen.configs.models import VAEConfig
from sglang.multimodal_gen.runtime.distributed import (
get_sp_parallel_rank,
get_sp_world_size,
)
class ParallelTiledVAE(ABC):
tile_sample_min_height: int
tile_sample_min_width: int
tile_sample_min_num_frames: int
tile_sample_stride_height: int
tile_sample_stride_width: int
tile_sample_stride_num_frames: int
blend_num_frames: int
use_tiling: bool
use_temporal_tiling: bool
use_parallel_tiling: bool
def __init__(self, config: VAEConfig, **kwargs) -> None:
self.config = config
self.tile_sample_min_height = config.tile_sample_min_height
self.tile_sample_min_width = config.tile_sample_min_width
self.tile_sample_min_num_frames = config.tile_sample_min_num_frames
self.tile_sample_stride_height = config.tile_sample_stride_height
self.tile_sample_stride_width = config.tile_sample_stride_width
self.tile_sample_stride_num_frames = config.tile_sample_stride_num_frames
self.blend_num_frames = config.blend_num_frames
self.use_tiling = config.use_tiling
self.use_temporal_tiling = config.use_temporal_tiling
self.use_parallel_tiling = config.use_parallel_tiling
def to(self, device) -> "ParallelTiledVAE":
# TODO: implement this
return self
@property
def device(self):
return next(self.parameters()).device
@property
def temporal_compression_ratio(self) -> int:
return cast(int, self.config.temporal_compression_ratio)
@property
def spatial_compression_ratio(self) -> int:
return cast(int, self.config.spatial_compression_ratio)
@property
def scaling_factor(self) -> float | torch.Tensor:
return cast(float | torch.Tensor, self.config.scaling_factor)
@abstractmethod
def _encode(self, *args, **kwargs) -> torch.Tensor:
pass
@abstractmethod
def _decode(self, *args, **kwargs) -> torch.Tensor:
pass
def encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
if (
self.use_tiling
and self.use_temporal_tiling
and num_frames > self.tile_sample_min_num_frames
):
latents = self.tiled_encode(x)[:, :, :latent_num_frames]
elif self.use_tiling and (
width > self.tile_sample_min_width or height > self.tile_sample_min_height
):
latents = self.spatial_tiled_encode(x)[:, :, :latent_num_frames]
else:
latents = self._encode(x)[:, :, :latent_num_frames]
return DiagonalGaussianDistribution(latents)
def decode(self, z: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
        tile_latent_min_width = (
            self.tile_sample_min_width // self.spatial_compression_ratio
        )
tile_latent_min_num_frames = (
self.tile_sample_min_num_frames // self.temporal_compression_ratio
)
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
if self.use_tiling and self.use_parallel_tiling and get_sp_world_size() > 1:
return self.parallel_tiled_decode(z)[:, :, :num_sample_frames]
if (
self.use_tiling
and self.use_temporal_tiling
and num_frames > tile_latent_min_num_frames
):
return self.tiled_decode(z)[:, :, :num_sample_frames]
if self.use_tiling and (
width > tile_latent_min_width or height > tile_latent_min_height
):
return self.spatial_tiled_decode(z)[:, :, :num_sample_frames]
return self._decode(z)[:, :, :num_sample_frames]
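    # Illustrative sketch (not executed): decode() picks the first applicable path in this order,
    # assuming tiling is enabled:
    #   parallel_tiled_decode  when use_parallel_tiling is set and the SP world size is > 1
    #   tiled_decode           when temporal tiling is on and the clip has enough latent frames
    #   spatial_tiled_decode   when only the spatial extent exceeds the tile thresholds
    #   _decode                otherwise (a single full-tensor pass)
    # Every path is cropped to num_sample_frames = (num_frames - 1) * temporal_compression_ratio + 1.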
def blend_v(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
1 - y / blend_extent
) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
1 - x / blend_extent
) + b[:, :, :, :, x] * (x / blend_extent)
return b
def blend_t(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (
1 - x / blend_extent
) + b[:, :, x, :, :] * (x / blend_extent)
return b
def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_, _, _, height, width = x.shape
# latent_height = height // self.spatial_compression_ratio
# latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
tile_latent_min_width = (
self.tile_sample_min_width // self.spatial_compression_ratio
)
tile_latent_stride_height = (
self.tile_sample_stride_height // self.spatial_compression_ratio
)
tile_latent_stride_width = (
self.tile_sample_stride_width // self.spatial_compression_ratio
)
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self._encode(tile)
row.append(tile)
rows.append(row)
return self._merge_spatial_tiles(
rows,
blend_height,
blend_width,
tile_latent_stride_height,
tile_latent_stride_width,
)
def _parallel_data_generator(
self, gathered_results, gathered_dim_metadata
) -> Iterator[tuple[torch.Tensor, int]]:
global_idx = 0
for i, per_rank_metadata in enumerate(gathered_dim_metadata):
_start_shape = 0
for shape in per_rank_metadata:
mul_shape = prod(shape)
yield (
gathered_results[
i, _start_shape : _start_shape + mul_shape
].reshape(shape),
global_idx,
)
_start_shape += mul_shape
global_idx += 1
def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
"""
Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs
"""
world_size, rank = get_sp_world_size(), get_sp_parallel_rank()
B, C, T, H, W = z.shape
# Calculate parameters
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
tile_latent_min_width = (
self.tile_sample_min_width // self.spatial_compression_ratio
)
tile_latent_min_num_frames = (
self.tile_sample_min_num_frames // self.temporal_compression_ratio
)
tile_latent_stride_height = (
self.tile_sample_stride_height // self.spatial_compression_ratio
)
tile_latent_stride_width = (
self.tile_sample_stride_width // self.spatial_compression_ratio
)
tile_latent_stride_num_frames = (
self.tile_sample_stride_num_frames // self.temporal_compression_ratio
)
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Calculate tile dimensions
num_t_tiles = (
T + tile_latent_stride_num_frames - 1
) // tile_latent_stride_num_frames
num_h_tiles = (H + tile_latent_stride_height - 1) // tile_latent_stride_height
num_w_tiles = (W + tile_latent_stride_width - 1) // tile_latent_stride_width
total_spatial_tiles = num_h_tiles * num_w_tiles
total_tiles = num_t_tiles * total_spatial_tiles
# Calculate tiles per rank and padding
tiles_per_rank = (total_tiles + world_size - 1) // world_size
start_tile_idx = rank * tiles_per_rank
end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles)
local_results = []
local_dim_metadata = []
# Process assigned tiles
for local_idx, global_idx in enumerate(range(start_tile_idx, end_tile_idx)):
t_idx = global_idx // total_spatial_tiles
spatial_idx = global_idx % total_spatial_tiles
h_idx = spatial_idx // num_w_tiles
w_idx = spatial_idx % num_w_tiles
# Calculate positions
t_start = t_idx * tile_latent_stride_num_frames
h_start = h_idx * tile_latent_stride_height
w_start = w_idx * tile_latent_stride_width
# Extract and process tile
tile = z[
:,
:,
t_start : t_start + tile_latent_min_num_frames + 1,
h_start : h_start + tile_latent_min_height,
w_start : w_start + tile_latent_min_width,
]
# Process tile
tile = self._decode(tile)
if t_start > 0:
tile = tile[:, :, 1:, :, :]
# Store metadata
shape = tile.shape
# Store decoded data (flattened)
decoded_flat = tile.reshape(-1)
local_results.append(decoded_flat)
local_dim_metadata.append(shape)
results = torch.cat(local_results, dim=0).contiguous()
del local_results
# first gather size to pad the results
local_size = torch.tensor(
[results.size(0)], device=results.device, dtype=torch.int64
)
all_sizes = [
torch.zeros(1, device=results.device, dtype=torch.int64)
for _ in range(world_size)
]
dist.all_gather(all_sizes, local_size)
max_size = max(size.item() for size in all_sizes)
padded_results = torch.zeros(max_size, device=results.device)
padded_results[: results.size(0)] = results
del results
# Gather all results
gathered_dim_metadata = [None] * world_size
gathered_results = (
torch.zeros_like(padded_results)
.repeat(world_size, *[1] * len(padded_results.shape))
.contiguous()
        )  # call .contiguous() so the collectives below operate on a densely laid-out buffer without extra copies
# TODO (PY): use sgl_diffusion distributed methods
dist.all_gather_into_tensor(gathered_results, padded_results)
dist.all_gather_object(gathered_dim_metadata, local_dim_metadata)
# Process gathered results
data: list = [
[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)]
for _ in range(num_t_tiles)
]
for current_data, global_idx in self._parallel_data_generator(
gathered_results, gathered_dim_metadata
):
t_idx = global_idx // total_spatial_tiles
spatial_idx = global_idx % total_spatial_tiles
h_idx = spatial_idx // num_w_tiles
w_idx = spatial_idx % num_w_tiles
data[t_idx][h_idx][w_idx] = current_data
# Merge results
result_slices = []
last_slice_data = None
for i, tem_data in enumerate(data):
slice_data = self._merge_spatial_tiles(
tem_data,
blend_height,
blend_width,
self.tile_sample_stride_height,
self.tile_sample_stride_width,
)
if i > 0:
slice_data = self.blend_t(
last_slice_data, slice_data, self.blend_num_frames
)
result_slices.append(
slice_data[:, :, : self.tile_sample_stride_num_frames, :, :]
)
else:
result_slices.append(
slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :]
)
last_slice_data = slice_data
dec = torch.cat(result_slices, dim=2)
return dec
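    # Illustrative sketch (not executed): tiles are enumerated as a flat index global_idx in
    # [0, num_t_tiles * num_h_tiles * num_w_tiles) and sliced evenly across SP ranks. For example,
    # with num_t_tiles=3, num_h_tiles=2, num_w_tiles=2 and world_size=4 (hypothetical values):
    #   total_tiles = 12, tiles_per_rank = 3, and rank 1 decodes global indices 3..5, which unpack as
    #   t_idx = global_idx // 4, h_idx = (global_idx % 4) // 2, w_idx = global_idx % 2
    # After the all-gather, every rank blends the full set of decoded tiles back together.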
def _merge_spatial_tiles(
self, tiles, blend_height, blend_width, stride_height, stride_width
) -> torch.Tensor:
"""Helper function to merge spatial tiles with blending"""
result_rows = []
for i, row in enumerate(tiles):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(tiles[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :stride_height, :stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
return torch.cat(result_rows, dim=-2)
def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
Returns:
`torch.Tensor`:
The decoded images.
"""
_, _, _, height, width = z.shape
# sample_height = height * self.spatial_compression_ratio
# sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
tile_latent_min_width = (
self.tile_sample_min_width // self.spatial_compression_ratio
)
tile_latent_stride_height = (
self.tile_sample_stride_height // self.spatial_compression_ratio
)
tile_latent_stride_width = (
self.tile_sample_stride_width // self.spatial_compression_ratio
)
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
tile = z[
:,
:,
:,
i : i + tile_latent_min_height,
j : j + tile_latent_min_width,
]
decoded = self._decode(tile)
row.append(decoded)
rows.append(row)
return self._merge_spatial_tiles(
rows,
blend_height,
blend_width,
self.tile_sample_stride_height,
self.tile_sample_stride_width,
)
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, num_frames, height, width = x.shape
# tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
tile_latent_stride_num_frames = (
self.tile_sample_stride_num_frames // self.temporal_compression_ratio
)
row = []
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
if self.use_tiling and (
height > self.tile_sample_min_height
or width > self.tile_sample_min_width
):
tile = self.spatial_tiled_encode(tile)
else:
tile = self._encode(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, self.blend_num_frames)
result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
else:
result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
enc = torch.cat(result_row, dim=2)
return enc
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = (
self.tile_sample_min_height // self.spatial_compression_ratio
)
tile_latent_min_width = (
self.tile_sample_min_width // self.spatial_compression_ratio
)
tile_latent_min_num_frames = (
self.tile_sample_min_num_frames // self.temporal_compression_ratio
)
tile_latent_stride_num_frames = (
self.tile_sample_stride_num_frames // self.temporal_compression_ratio
)
row = []
for i in range(0, num_frames, tile_latent_stride_num_frames):
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
if self.use_tiling and (
tile.shape[-1] > tile_latent_min_width
or tile.shape[-2] > tile_latent_min_height
):
decoded = self.spatial_tiled_decode(tile)
else:
decoded = self._decode(tile)
if i > 0:
decoded = decoded[:, :, 1:, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, self.blend_num_frames)
result_row.append(
tile[:, :, : self.tile_sample_stride_num_frames, :, :]
)
else:
result_row.append(
tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]
)
dec = torch.cat(result_row, dim=2)
return dec
def enable_tiling(
self,
tile_sample_min_height: int | None = None,
tile_sample_min_width: int | None = None,
tile_sample_min_num_frames: int | None = None,
tile_sample_stride_height: int | None = None,
tile_sample_stride_width: int | None = None,
tile_sample_stride_num_frames: int | None = None,
blend_num_frames: int | None = None,
use_tiling: bool | None = None,
use_temporal_tiling: bool | None = None,
use_parallel_tiling: bool | None = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
        processing larger images and videos.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_min_num_frames (`int`, *optional*):
The minimum number of frames required for a sample to be separated into tiles across the frame
dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
                artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
            tile_sample_stride_num_frames (`int`, *optional*):
                The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
                produced across the frame dimension.
            blend_num_frames (`int`, *optional*):
                Number of frames blended between two consecutive temporal tiles. Defaults to
                `tile_sample_min_num_frames - tile_sample_stride_num_frames`.
            use_tiling (`bool`, *optional*):
                Whether to enable spatial tiling.
            use_temporal_tiling (`bool`, *optional*):
                Whether to enable tiling across the frame dimension.
            use_parallel_tiling (`bool`, *optional*):
                Whether to process the tiles in parallel.
"""
self.use_tiling = True
self.tile_sample_min_height = (
tile_sample_min_height or self.tile_sample_min_height
)
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_min_num_frames = (
tile_sample_min_num_frames or self.tile_sample_min_num_frames
)
self.tile_sample_stride_height = (
tile_sample_stride_height or self.tile_sample_stride_height
)
self.tile_sample_stride_width = (
tile_sample_stride_width or self.tile_sample_stride_width
)
self.tile_sample_stride_num_frames = (
tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
)
if blend_num_frames is not None:
self.blend_num_frames = blend_num_frames
else:
self.blend_num_frames = (
self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
)
self.use_tiling = use_tiling or self.use_tiling
self.use_temporal_tiling = use_temporal_tiling or self.use_temporal_tiling
self.use_parallel_tiling = use_parallel_tiling or self.use_parallel_tiling
def disable_tiling(self) -> None:
r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously invoked, this method will go back to computing
        decoding in one step.
"""
self.use_tiling = False
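# Minimal usage sketch (illustrative only): it assumes these tiling methods live on
# a `ParallelTiledVAE` subclass and that the tile sizes below are sensible for the
# model at hand; it enables tiling and lists the latent-space tile corners that the
# spatial tiled decode loop above would visit.
def _example_tiled_decode_grid(vae, latent_height=60, latent_width=104):
    vae.enable_tiling(
        tile_sample_min_height=256,
        tile_sample_min_width=256,
        tile_sample_stride_height=192,
        tile_sample_stride_width=192,
    )
    stride_h = vae.tile_sample_stride_height // vae.spatial_compression_ratio
    stride_w = vae.tile_sample_stride_width // vae.spatial_compression_ratio
    return [
        (i, j)
        for i in range(0, latent_height, stride_h)
        for j in range(0, latent_width, stride_w)
    ]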
# adapted from https://github.com/huggingface/diffusers/blob/e7ffeae0a191f710881d1fbde00cd6ff025e81f2/src/diffusers/models/autoencoders/vae.py#L691
class DiagonalGaussianDistribution:
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(
self, other: Optional["DiagonalGaussianDistribution"] = None
) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(
self, sample: torch.Tensor, dims: tuple[int, ...] = (1, 2, 3)
) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
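# Small shape-check sketch (illustrative only): the posterior parameters are a
# concatenation of mean and log-variance along dim=1, and kl() with no second
# distribution reduces over dims (1, 2, 3).
def _example_diagonal_gaussian_shapes():
    params = torch.randn(2, 8, 16, 16)  # (B, 2 * latent_channels, H, W)
    dist = DiagonalGaussianDistribution(params)
    assert dist.mean.shape == (2, 4, 16, 16)
    assert dist.mode().shape == (2, 4, 16, 16)
    assert dist.kl().shape == (2,)
    return dist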
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from diffusers
# Copyright 2024 The Hunyuan Team, The HuggingFace Team and The sgl-diffusion Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig
from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE
def prepare_causal_attention_mask(
num_frames: int,
height_width: int,
dtype: torch.dtype,
device: torch.device,
batch_size: int | None = None,
) -> torch.Tensor:
indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
indices_blocks = indices.repeat_interleave(height_width)
x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
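# Worked example (illustrative only): for 2 frames with 2 spatial positions each,
# the mask is block lower-triangular at frame granularity, i.e. a token can attend
# to every token of its own frame and of earlier frames, but not to later frames.
def _example_causal_attention_mask():
    mask = prepare_causal_attention_mask(
        num_frames=2, height_width=2, dtype=torch.float32, device=torch.device("cpu")
    )
    assert mask.shape == (4, 4)
    assert torch.isinf(mask[0, 2])  # frame 0 cannot attend to frame 1
    assert mask[2, 0] == 0  # frame 1 can attend to frame 0
    return mask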
class HunyuanVAEAttention(nn.Module):
def __init__(
self, in_channels, heads, dim_head, eps, norm_num_groups, bias
) -> None:
super().__init__()
self.in_channels = in_channels
self.heads = heads
self.dim_head = dim_head
self.eps = eps
self.norm_num_groups = norm_num_groups
self.bias = bias
inner_dim = heads * dim_head
# Define the projection layers
self.to_q = nn.Linear(in_channels, inner_dim, bias=bias)
self.to_k = nn.Linear(in_channels, inner_dim, bias=bias)
self.to_v = nn.Linear(in_channels, inner_dim, bias=bias)
self.to_out = nn.Sequential(nn.Linear(inner_dim, in_channels, bias=bias))
# Optional normalization layers
self.group_norm = nn.GroupNorm(
norm_num_groups, in_channels, eps=eps, affine=True
)
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
residual = hidden_states
batch_size, sequence_length, _ = hidden_states.shape
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
# Project to query, key, value
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
# Reshape for multi-head attention
head_dim = self.dim_head
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Perform scaled dot-product attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, self.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# Linear projection
hidden_states = self.to_out(hidden_states)
        # Residual connection
hidden_states = hidden_states + residual
return hidden_states
class HunyuanVideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int] = 3,
stride: int | tuple[int, int, int] = 1,
padding: int | tuple[int, int, int] = 0,
dilation: int | tuple[int, int, int] = 1,
bias: bool = True,
pad_mode: str = "replicate",
) -> None:
super().__init__()
kernel_size = (
(kernel_size, kernel_size, kernel_size)
if isinstance(kernel_size, int)
else kernel_size
)
self.pad_mode = pad_mode
self.time_causal_padding = (
kernel_size[0] // 2,
kernel_size[0] // 2,
kernel_size[1] // 2,
kernel_size[1] // 2,
kernel_size[2] - 1,
0,
)
self.conv = nn.Conv3d(
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(
hidden_states, self.time_causal_padding, mode=self.pad_mode
)
return self.conv(hidden_states)
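# Shape sanity sketch (illustrative only): the asymmetric temporal padding
# (kernel_size - 1 frames in the past, none in the future) keeps the frame count
# unchanged at stride 1, while height and width are padded symmetrically.
def _example_causal_conv3d_shape():
    conv = HunyuanVideoCausalConv3d(in_channels=4, out_channels=8, kernel_size=3)
    x = torch.randn(1, 4, 5, 16, 16)  # (B, C, T, H, W)
    assert conv(x).shape == (1, 8, 5, 16, 16)
    return conv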
class HunyuanVideoUpsampleCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
upsample_factor: tuple[int, ...] = (2, 2, 2),
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.upsample_factor = upsample_factor
self.conv = HunyuanVideoCausalConv3d(
in_channels, out_channels, kernel_size, stride, bias=bias
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_frames = hidden_states.size(2)
first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
first_frame = F.interpolate(
first_frame.squeeze(2),
scale_factor=self.upsample_factor[1:],
mode="nearest",
).unsqueeze(2)
if num_frames > 1:
# See: https://github.com/pytorch/pytorch/issues/81665
# Unless you have a version of pytorch where non-contiguous implementation of F.interpolate
# is fixed, this will raise either a runtime error, or fail silently with bad outputs.
# If you are encountering an error here, make sure to try running encoding/decoding with
# `vae.enable_tiling()` first. If that doesn't work, open an issue at:
# https://github.com/huggingface/diffusers/issues
other_frames = other_frames.contiguous()
other_frames = F.interpolate(
other_frames, scale_factor=self.upsample_factor, mode="nearest"
)
hidden_states = torch.cat((first_frame, other_frames), dim=2)
else:
hidden_states = first_frame
hidden_states = self.conv(hidden_states)
return hidden_states
class HunyuanVideoDownsampleCausal3D(nn.Module):
def __init__(
self,
channels: int,
out_channels: int | None = None,
padding: int = 1,
kernel_size: int = 3,
bias: bool = True,
stride=2,
) -> None:
super().__init__()
out_channels = out_channels or channels
self.conv = HunyuanVideoCausalConv3d(
channels, out_channels, kernel_size, stride, padding, bias=bias
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv(hidden_states)
return hidden_states
class HunyuanVideoResnetBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
dropout: float = 0.0,
groups: int = 32,
eps: float = 1e-6,
non_linearity: str = "silu",
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_act_fn(non_linearity)
self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)
self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
self.dropout = nn.Dropout(dropout)
self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)
self.conv_shortcut = None
if in_channels != out_channels:
self.conv_shortcut = HunyuanVideoCausalConv3d(
in_channels, out_channels, 1, 1, 0
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.contiguous()
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
hidden_states = hidden_states + residual
return hidden_states
class HunyuanVideoMidBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "silu",
resnet_groups: int = 32,
add_attention: bool = True,
attention_head_dim: int = 1,
) -> None:
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.add_attention = add_attention
# There is always at least one resnet
resnets = [
HunyuanVideoResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
]
attentions: list[HunyuanVAEAttention | None] = []
for _ in range(num_layers):
if self.add_attention:
attentions.append(
HunyuanVAEAttention(
in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
eps=resnet_eps,
norm_num_groups=resnet_groups,
bias=True,
)
)
else:
attentions.append(None)
resnets.append(
HunyuanVideoResnetBlockCausal3D(
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
self.resnets[0], hidden_states
)
for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True):
if attn is not None:
batch_size, num_channels, num_frames, height, width = (
hidden_states.shape
)
hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
attention_mask = prepare_causal_attention_mask(
num_frames,
height * width,
hidden_states.dtype,
hidden_states.device,
batch_size=batch_size,
)
hidden_states = attn(hidden_states, attention_mask=attention_mask)
hidden_states = hidden_states.unflatten(
1, (num_frames, height, width)
).permute(0, 4, 1, 2, 3)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
hidden_states = self.resnets[0](hidden_states)
for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True):
if attn is not None:
batch_size, num_channels, num_frames, height, width = (
hidden_states.shape
)
hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
attention_mask = prepare_causal_attention_mask(
num_frames,
height * width,
hidden_states.dtype,
hidden_states.device,
batch_size=batch_size,
)
hidden_states = attn(hidden_states, attention_mask=attention_mask)
hidden_states = hidden_states.unflatten(
1, (num_frames, height, width)
).permute(0, 4, 1, 2, 3)
hidden_states = resnet(hidden_states)
return hidden_states
class HunyuanVideoDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "silu",
resnet_groups: int = 32,
add_downsample: bool = True,
downsample_stride: tuple[int, ...] | int = 2,
downsample_padding: int = 1,
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanVideoResnetBlockCausal3D(
in_channels=in_channels,
out_channels=out_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
HunyuanVideoDownsampleCausal3D(
out_channels,
out_channels=out_channels,
padding=downsample_padding,
stride=downsample_stride,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class HunyuanVideoUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "silu",
resnet_groups: int = 32,
add_upsample: bool = True,
upsample_scale_factor: tuple[int, ...] = (2, 2, 2),
) -> None:
super().__init__()
resnets = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
HunyuanVideoResnetBlockCausal3D(
in_channels=input_channels,
out_channels=out_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList(
[
HunyuanVideoUpsampleCausal3D(
out_channels,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for resnet in self.resnets:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
else:
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class HunyuanVideoEncoder3D(nn.Module):
r"""
Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: tuple[str, ...] = (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention=True,
temporal_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
) -> None:
super().__init__()
self.conv_in = HunyuanVideoCausalConv3d(
in_channels, block_out_channels[0], kernel_size=3, stride=1
)
self.mid_block: HunyuanVideoMidBlock3D | None = None
self.down_blocks = nn.ModuleList([])
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
if down_block_type != "HunyuanVideoDownBlock3D":
raise ValueError(f"Unsupported down_block_type: {down_block_type}")
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
num_time_downsample_layers = int(np.log2(temporal_compression_ratio))
if temporal_compression_ratio == 4:
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(
i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
and not is_final_block
)
elif temporal_compression_ratio == 8:
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
add_time_downsample = bool(i < num_time_downsample_layers)
else:
raise ValueError(
f"Unsupported time_compression_ratio: {temporal_compression_ratio}"
)
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
downsample_stride_T = (2,) if add_time_downsample else (1,)
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
down_block = HunyuanVideoDownBlock3D(
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=bool(add_spatial_downsample or add_time_downsample),
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample_stride=downsample_stride,
downsample_padding=0,
)
self.down_blocks.append(down_block)
self.mid_block = HunyuanVideoMidBlock3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
add_attention=mid_block_add_attention,
)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = HunyuanVideoCausalConv3d(
block_out_channels[-1], conv_out_channels, kernel_size=3
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(
down_block, hidden_states
)
hidden_states = self._gradient_checkpointing_func(
self.mid_block, hidden_states
)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
assert self.mid_block is not None
hidden_states = self.mid_block(hidden_states)
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class HunyuanVideoDecoder3D(nn.Module):
r"""
Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: tuple[str, ...] = (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels: tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
mid_block_add_attention=True,
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = HunyuanVideoCausalConv3d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1
)
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = HunyuanVideoMidBlock3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
if up_block_type != "HunyuanVideoUpBlock3D":
raise ValueError(f"Unsupported up_block_type: {up_block_type}")
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
num_time_upsample_layers = int(np.log2(time_compression_ratio))
if time_compression_ratio == 4:
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
add_time_upsample = bool(
i >= len(block_out_channels) - 1 - num_time_upsample_layers
and not is_final_block
)
else:
raise ValueError(
f"Unsupported time_compression_ratio: {time_compression_ratio}"
)
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
upsample_scale_factor = tuple(
upsample_scale_factor_T + upsample_scale_factor_HW
)
up_block = HunyuanVideoUpBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=bool(add_spatial_upsample or add_time_upsample),
upsample_scale_factor=upsample_scale_factor,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
self.conv_out = HunyuanVideoCausalConv3d(
block_out_channels[0], out_channels, kernel_size=3
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
self.mid_block, hidden_states
)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(
up_block, hidden_states
)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
# post-process
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLHunyuanVideo(nn.Module, ParallelTiledVAE):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
    This model inherits from `nn.Module` and `ParallelTiledVAE`. Check the `ParallelTiledVAE` documentation for its
    generic tiling and parallel decoding methods.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
config: HunyuanVAEConfig,
) -> None:
nn.Module.__init__(self)
ParallelTiledVAE.__init__(self, config)
# TODO(will): only pass in config. We do this by manually defining a
# config for hunyuan vae
self.block_out_channels = config.block_out_channels
if config.load_encoder:
self.encoder = HunyuanVideoEncoder3D(
in_channels=config.in_channels,
out_channels=config.latent_channels,
down_block_types=config.down_block_types,
block_out_channels=config.block_out_channels,
layers_per_block=config.layers_per_block,
norm_num_groups=config.norm_num_groups,
act_fn=config.act_fn,
double_z=True,
mid_block_add_attention=config.mid_block_add_attention,
temporal_compression_ratio=config.temporal_compression_ratio,
spatial_compression_ratio=config.spatial_compression_ratio,
)
self.quant_conv = nn.Conv3d(
2 * config.latent_channels, 2 * config.latent_channels, kernel_size=1
)
if config.load_decoder:
self.decoder = HunyuanVideoDecoder3D(
in_channels=config.latent_channels,
out_channels=config.out_channels,
up_block_types=config.up_block_types,
block_out_channels=config.block_out_channels,
layers_per_block=config.layers_per_block,
norm_num_groups=config.norm_num_groups,
act_fn=config.act_fn,
time_compression_ratio=config.temporal_compression_ratio,
spatial_compression_ratio=config.spatial_compression_ratio,
mid_block_add_attention=config.mid_block_add_attention,
)
self.post_quant_conv = nn.Conv3d(
config.latent_channels, config.latent_channels, kernel_size=1
)
def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
return enc
def _decode(self, z: torch.Tensor) -> torch.Tensor:
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
generator: torch.Generator | None = None,
) -> torch.Tensor:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
return dec
EntryClass = AutoencoderKLHunyuanVideo
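# Minimal end-to-end sketch (illustrative only): it assumes `HunyuanVAEConfig()`
# provides usable defaults with both encoder and decoder enabled, and that
# `ParallelTiledVAE.encode` returns an object exposing `.latent_dist`, as relied on
# by `forward` above. The frame count and resolution below are arbitrary examples.
def _example_hunyuan_vae_roundtrip():
    vae = AutoencoderKLHunyuanVideo(HunyuanVAEConfig()).eval()
    video = torch.randn(1, 3, 9, 64, 64)  # (B, C, T, H, W)
    with torch.no_grad():
        recon = vae(video, sample_posterior=False)
    return recon.shape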
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 StepFun Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
from typing import Any
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from sglang.multimodal_gen.configs.models.vaes import StepVideoVAEConfig
from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE
def base_group_norm(x, norm_layer, act_silu=False, channel_last=False) -> torch.Tensor:
    # `base_group_norm.spatial` is a function attribute set by AutoencoderKLStepvideo
    # (True for version-2 models); when set, the time dimension is folded into the
    # batch dimension so that every frame is normalized independently.
    if hasattr(base_group_norm, "spatial") and base_group_norm.spatial:
assert channel_last
x_shape = x.shape
x = x.flatten(0, 1)
if channel_last:
# Permute to NCHW format
x = x.permute(0, 3, 1, 2)
out = F.group_norm(
x.contiguous(),
norm_layer.num_groups,
norm_layer.weight,
norm_layer.bias,
norm_layer.eps,
)
if act_silu:
out = F.silu(out)
if channel_last:
# Permute back to NHWC format
out = out.permute(0, 2, 3, 1)
out = out.view(x_shape)
else:
if channel_last:
# Permute to NCHW format
x = x.permute(0, 3, 1, 2)
out = F.group_norm(
x.contiguous(),
norm_layer.num_groups,
norm_layer.weight,
norm_layer.bias,
norm_layer.eps,
)
if act_silu:
out = F.silu(out)
if channel_last:
# Permute back to NHWC format
out = out.permute(0, 2, 3, 1)
return out
def base_conv2d(x, conv_layer, channel_last=False, residual=None) -> torch.Tensor:
if channel_last:
x = x.permute(0, 3, 1, 2) # NHWC to NCHW
out = F.conv2d(
x,
conv_layer.weight,
conv_layer.bias,
stride=conv_layer.stride,
padding=conv_layer.padding,
)
if residual is not None:
if channel_last:
residual = residual.permute(0, 3, 1, 2) # NHWC to NCHW
out += residual
if channel_last:
out = out.permute(0, 2, 3, 1) # NCHW to NHWC
return out
def base_conv3d(
x, conv_layer, channel_last=False, residual=None, only_return_output=False
) -> torch.Tensor:
if only_return_output:
size = cal_outsize(
x.shape, conv_layer.weight.shape, conv_layer.stride, conv_layer.padding
)
return torch.empty(size, device=x.device, dtype=x.dtype)
if channel_last:
x = x.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW
out = F.conv3d(
x,
conv_layer.weight,
conv_layer.bias,
stride=conv_layer.stride,
padding=conv_layer.padding,
)
if residual is not None:
if channel_last:
residual = residual.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW
out += residual
if channel_last:
out = out.permute(0, 2, 3, 4, 1) # NCDHW to NDHWC
return out
def cal_outsize(input_sizes, kernel_sizes, stride, padding) -> list:
stride_d, stride_h, stride_w = stride
padding_d, padding_h, padding_w = padding
dilation_d, dilation_h, dilation_w = 1, 1, 1
in_d = input_sizes[1]
in_h = input_sizes[2]
in_w = input_sizes[3]
kernel_d = kernel_sizes[2]
kernel_h = kernel_sizes[3]
kernel_w = kernel_sizes[4]
out_channels = kernel_sizes[0]
out_d = calc_out_(in_d, padding_d, dilation_d, kernel_d, stride_d)
out_h = calc_out_(in_h, padding_h, dilation_h, kernel_h, stride_h)
out_w = calc_out_(in_w, padding_w, dilation_w, kernel_w, stride_w)
size = [input_sizes[0], out_d, out_h, out_w, out_channels]
return size
def calc_out_(
in_size: int, padding: int, dilation: int, kernel: int, stride: int
) -> int:
return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
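# Worked example of the output-size formula used by `cal_outsize` / `calc_out_`:
# out = (in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1.
def _example_conv_out_size():
    assert calc_out_(in_size=17, padding=0, dilation=1, kernel=3, stride=1) == 15
    assert calc_out_(in_size=16, padding=1, dilation=1, kernel=3, stride=2) == 8
    return True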
def base_conv3d_channel_last(x, conv_layer, residual=None) -> torch.Tensor:
in_numel = x.numel()
out_numel = int(x.numel() * conv_layer.out_channels / conv_layer.in_channels)
if (in_numel >= 2**30) or (out_numel >= 2**30):
        assert conv_layer.stride[0] == 1, "temporal chunking requires time stride == 1"
B, T, H, W, C = x.shape
K = conv_layer.kernel_size[0]
chunks = 4
chunk_size = T // chunks
if residual is None:
out_nhwc = base_conv3d(
x,
conv_layer,
channel_last=True,
residual=residual,
only_return_output=True,
)
else:
out_nhwc = residual
assert B == 1
for i in range(chunks):
if i == chunks - 1:
xi = x[:1, chunk_size * i :]
out_nhwci = out_nhwc[:1, chunk_size * i :]
else:
xi = x[:1, chunk_size * i : chunk_size * (i + 1) + K - 1]
out_nhwci = out_nhwc[:1, chunk_size * i : chunk_size * (i + 1)]
if residual is not None:
if i == chunks - 1:
ri = residual[:1, chunk_size * i :]
else:
ri = residual[:1, chunk_size * i : chunk_size * (i + 1)]
else:
ri = None
out_nhwci.copy_(base_conv3d(xi, conv_layer, channel_last=True, residual=ri))
else:
out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual)
return out_nhwc
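# Note on the chunked path above (hedged explanation): activations with 2**30 or
# more elements are convolved in four temporal chunks so that no single conv call
# has to address a tensor close to 32-bit indexing limits; each input chunk is
# extended by kernel_size - 1 frames so that the outputs at chunk boundaries match
# the unchunked convolution exactly.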
class Upsample2D(nn.Module):
def __init__(
self, channels, use_conv=False, use_conv_transpose=False, out_channels=None
) -> None:
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
if use_conv:
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
else:
assert "Not Supported"
self.conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
def forward(self, x, output_size=None) -> torch.Tensor:
assert x.shape[-1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
if output_size is None:
x = (
F.interpolate(
x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
scale_factor=2.0,
mode="nearest",
)
.permute(0, 2, 3, 1)
.contiguous()
)
else:
x = (
F.interpolate(
x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last),
size=output_size,
mode="nearest",
)
.permute(0, 2, 3, 1)
.contiguous()
)
# x = self.conv(x)
x = base_conv2d(x, self.conv, channel_last=True)
return x
class Downsample2D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1) -> None:
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
if use_conv:
self.conv = nn.Conv2d(
self.channels, self.out_channels, 3, stride=stride, padding=padding
)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
def forward(self, x) -> torch.Tensor:
assert x.shape[-1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 0, 0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
assert x.shape[-1] == self.channels
# x = self.conv(x)
x = base_conv2d(x, self.conv, channel_last=True)
return x
class CausalConv(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size, **kwargs) -> None:
super().__init__()
if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.dilation = kwargs.pop("dilation", 1)
self.stride = kwargs.pop("stride", 1)
if isinstance(self.stride, int):
self.stride = (self.stride, 1, 1)
time_pad = self.dilation * (time_kernel_size - 1) + max((1 - self.stride[0]), 0)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.time_causal_padding = (
width_pad,
width_pad,
height_pad,
height_pad,
time_pad,
0,
)
self.time_uncausal_padding = (
width_pad,
width_pad,
height_pad,
height_pad,
0,
0,
)
self.conv = nn.Conv3d(
chan_in,
chan_out,
kernel_size,
stride=self.stride,
dilation=self.dilation,
**kwargs,
)
self.chan_in = chan_in
self.chan_out = chan_out
self.is_first_run = True
def forward(self, x, is_init=True, residual=None) -> torch.Tensor:
x = nn.functional.pad(
x, self.time_causal_padding if is_init else self.time_uncausal_padding
)
x = self.conv(x)
if residual is not None:
x.add_(residual)
return x
class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor: int,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor = factor
assert out_channels * factor**3 % in_channels == 0
self.repeats = out_channels * factor**3 // in_channels
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor,
self.factor,
self.factor,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor,
x.size(4) * self.factor,
x.size(6) * self.factor,
)
x = x[:, :, self.factor - 1 :, :, :]
return x
class ConvPixelShuffleUpSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
factor: int,
) -> None:
super().__init__()
self.factor = factor
out_ratio = factor**3
self.conv = CausalConv(
in_channels, out_channels * out_ratio, kernel_size=kernel_size
)
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
x = self.conv(x, is_init)
x = self.pixel_shuffle_3d(x, self.factor)
return x
@staticmethod
def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
batch_size, channels, depth, height, width = x.size()
new_channels = channels // (factor**3)
new_depth = depth * factor
new_height = height * factor
new_width = width * factor
x = x.view(
batch_size, new_channels, factor, factor, factor, depth, height, width
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(batch_size, new_channels, new_depth, new_height, new_width)
x = x[:, :, factor - 1 :, :, :]
return x
class ConvPixelUnshuffleDownSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
factor: int,
) -> None:
super().__init__()
self.factor = factor
out_ratio = factor**3
assert out_channels % out_ratio == 0
self.conv = CausalConv(
in_channels, out_channels // out_ratio, kernel_size=kernel_size
)
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
x = self.conv(x, is_init)
x = self.pixel_unshuffle_3d(x, self.factor)
return x
@staticmethod
def pixel_unshuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor:
pad = (0, 0, 0, 0, factor - 1, 0) # (left, right, top, bottom, front, back)
x = F.pad(x, pad)
B, C, D, H, W = x.shape
x = x.view(B, C, D // factor, factor, H // factor, factor, W // factor, factor)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(B, C * factor**3, D // factor, H // factor, W // factor)
return x
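# Shape sketch (illustrative only): pixel_unshuffle_3d folds a factor**3
# neighbourhood into the channel dimension after front-padding time by factor - 1
# frames, e.g. (1, 4, 5, 8, 8) -> (1, 32, 3, 4, 4) for factor=2; pixel_shuffle_3d
# above performs the inverse rearrangement and drops the padded leading frames.
def _example_pixel_unshuffle_3d_shape():
    x = torch.randn(1, 4, 5, 8, 8)  # (B, C, T, H, W)
    y = ConvPixelUnshuffleDownSampleLayer3D.pixel_unshuffle_3d(x, factor=2)
    assert y.shape == (1, 32, 3, 4, 4)
    return y.shape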
class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor: int,
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor = factor
assert in_channels * factor**3 % out_channels == 0
self.group_size = in_channels * factor**3 // out_channels
def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor:
pad = (
0,
0,
0,
0,
self.factor - 1,
0,
) # (left, right, top, bottom, front, back)
x = F.pad(x, pad)
B, C, D, H, W = x.shape
x = x.view(
B,
C,
D // self.factor,
self.factor,
H // self.factor,
self.factor,
W // self.factor,
self.factor,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B, C * self.factor**3, D // self.factor, H // self.factor, W // self.factor
)
x = x.view(
B,
self.out_channels,
self.group_size,
D // self.factor,
H // self.factor,
W // self.factor,
)
x = x.mean(dim=2)
return x
def base_group_norm_with_zero_pad(
x, norm_layer, act_silu=True, pad_size=2
) -> torch.Tensor:
out_shape = list(x.shape)
out_shape[1] += pad_size
out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
out[:, pad_size:] = base_group_norm(
x, norm_layer, act_silu=act_silu, channel_last=True
)
out[:, :pad_size] = 0
return out
class CausalConvChannelLast(CausalConv):
time_causal_padding: tuple[Any, ...]
time_uncausal_padding: tuple[Any, ...]
def __init__(self, chan_in, chan_out, kernel_size, **kwargs) -> None:
super().__init__(chan_in, chan_out, kernel_size, **kwargs)
self.time_causal_padding = (0, 0) + self.time_causal_padding
self.time_uncausal_padding = (0, 0) + self.time_uncausal_padding
def forward(self, x, is_init=True, residual=None) -> torch.Tensor:
if self.is_first_run:
self.is_first_run = False
# self.conv.weight = nn.Parameter(self.conv.weight.permute(0,2,3,4,1).contiguous())
x = nn.functional.pad(
x, self.time_causal_padding if is_init else self.time_uncausal_padding
)
x = base_conv3d_channel_last(x, self.conv, residual=residual)
return x
class CausalConvAfterNorm(CausalConv):
def __init__(self, chan_in, chan_out, kernel_size, **kwargs) -> None:
super().__init__(chan_in, chan_out, kernel_size, **kwargs)
if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
self.conv = nn.Conv3d(
chan_in,
chan_out,
kernel_size,
stride=self.stride,
dilation=self.dilation,
padding=(0, 1, 1),
**kwargs,
)
else:
self.conv = nn.Conv3d(
chan_in,
chan_out,
kernel_size,
stride=self.stride,
dilation=self.dilation,
**kwargs,
)
self.is_first_run = True
def forward(self, x, is_init=True, residual=None) -> torch.Tensor:
if self.is_first_run:
self.is_first_run = False
if self.time_causal_padding == (1, 1, 1, 1, 2, 0):
pass
else:
x = nn.functional.pad(x, self.time_causal_padding).contiguous()
x = base_conv3d_channel_last(x, self.conv, residual=residual)
return x
class AttnBlock(nn.Module):
def __init__(self, in_channels) -> None:
super().__init__()
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
def attention(self, x, is_init=True) -> torch.Tensor:
x = base_group_norm(x, self.norm, act_silu=False, channel_last=True)
q = self.q(x, is_init)
k = self.k(x, is_init)
v = self.v(x, is_init)
b, t, h, w, c = q.shape
q, k, v = map(lambda x: rearrange(x, "b t h w c -> b 1 (t h w) c"), (q, k, v))
x = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
x = rearrange(x, "b 1 (t h w) c -> b t h w c", t=t, h=h, w=w)
return x
def forward(self, x):
x = x.permute(0, 2, 3, 4, 1).contiguous()
h = self.attention(x)
x = self.proj_out(h, residual=x)
x = x.permute(0, 4, 1, 2, 3)
return x
class Resnet3DBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
temb_channels=512,
conv_shortcut=False,
) -> None:
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3)
assert conv_shortcut is False
self.use_conv_shortcut = conv_shortcut
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CausalConvAfterNorm(
in_channels, out_channels, kernel_size=3
)
else:
self.nin_shortcut = CausalConvAfterNorm(
in_channels, out_channels, kernel_size=1
)
def forward(self, x, temb=None, is_init=True) -> torch.Tensor:
x = x.permute(0, 2, 3, 4, 1).contiguous()
h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None]
x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2)
x = self.conv2(h, residual=x)
x = x.permute(0, 4, 1, 2, 3)
return x
class Downsample3D(nn.Module):
def __init__(self, in_channels, with_conv, stride) -> None:
super().__init__()
self.with_conv = with_conv
if with_conv:
self.conv = CausalConv(
in_channels, in_channels, kernel_size=3, stride=stride
)
def forward(self, x, is_init=True) -> torch.Tensor:
if self.with_conv:
x = self.conv(x, is_init)
else:
x = nn.functional.avg_pool3d(x, kernel_size=2, stride=2)
return x
class VideoEncoder(nn.Module):
def __init__(
self,
ch=32,
ch_mult=(4, 8, 16, 16),
num_res_blocks=2,
in_channels=3,
z_channels=16,
double_z=True,
down_sampling_layer=(1, 2),
resamp_with_conv=True,
version=1,
) -> None:
super().__init__()
temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
# downsampling
self.conv_in = CausalConv(in_channels, ch, kernel_size=3)
self.down_sampling_layer = down_sampling_layer
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
Resnet3DBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level in self.down_sampling_layer:
down.downsample = Downsample3D(
block_in, resamp_with_conv, stride=(2, 2, 2)
)
else:
down.downsample = Downsample2D(
block_in, resamp_with_conv, padding=0
) # DIFF
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = Resnet3DBlock(
in_channels=block_in, out_channels=block_in, temb_channels=temb_ch
)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = Resnet3DBlock(
in_channels=block_in, out_channels=block_in, temb_channels=temb_ch
)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
self.version = version
if version == 2:
channels = 4 * z_channels * 2**3
self.conv_patchify = ConvPixelUnshuffleDownSampleLayer3D(
block_in, channels, kernel_size=3, factor=2
)
self.shortcut_pathify = PixelUnshuffleChannelAveragingDownSampleLayer3D(
block_in, channels, 2
)
self.shortcut_out = PixelUnshuffleChannelAveragingDownSampleLayer3D(
channels, 2 * z_channels if double_z else z_channels, 1
)
self.conv_out = CausalConvChannelLast(
channels, 2 * z_channels if double_z else z_channels, kernel_size=3
)
else:
self.conv_out = CausalConvAfterNorm(
block_in, 2 * z_channels if double_z else z_channels, kernel_size=3
)
@torch.inference_mode()
def forward(self, x, video_frame_num, is_init=True) -> torch.Tensor:
# timestep embedding
temb = None
t = video_frame_num
# downsampling
h = self.conv_in(x, is_init)
        # materialize a channels-last copy while keeping the logical (B, C, T, H, W) layout
h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb, is_init)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
if isinstance(self.down[i_level].downsample, Downsample2D):
_, _, t, _, _ = h.shape
h = rearrange(h, "b c t h w -> (b t) h w c", t=t)
h = self.down[i_level].downsample(h)
h = rearrange(h, "(b t) h w c -> b c t h w", t=t)
else:
h = self.down[i_level].downsample(h, is_init)
h = self.mid.block_1(h, temb, is_init)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb, is_init)
h = h.permute(0, 2, 3, 4, 1).contiguous() # b c l h w -> b l h w c
if self.version == 2:
h = base_group_norm(h, self.norm_out, act_silu=True, channel_last=True)
h = h.permute(0, 4, 1, 2, 3).contiguous()
shortcut = self.shortcut_pathify(h, is_init)
h = self.conv_patchify(h, is_init)
h = h.add_(shortcut)
shortcut = self.shortcut_out(h, is_init).permute(0, 2, 3, 4, 1)
h = self.conv_out(h.permute(0, 2, 3, 4, 1).contiguous(), is_init)
h = h.add_(shortcut)
else:
h = base_group_norm_with_zero_pad(
h, self.norm_out, act_silu=True, pad_size=2
)
h = self.conv_out(h, is_init)
h = h.permute(0, 4, 1, 2, 3) # b l h w c -> b c l h w
h = rearrange(h, "b c t h w -> b t c h w")
return h
class Res3DBlockUpsample(nn.Module):
def __init__(
self, input_filters, num_filters, down_sampling_stride, down_sampling=False
) -> None:
super().__init__()
self.input_filters = input_filters
self.num_filters = num_filters
self.act_ = nn.SiLU(inplace=True)
self.conv1 = CausalConvChannelLast(
num_filters, num_filters, kernel_size=[3, 3, 3]
)
self.norm1 = nn.GroupNorm(32, num_filters)
self.conv2 = CausalConvChannelLast(
num_filters, num_filters, kernel_size=[3, 3, 3]
)
self.norm2 = nn.GroupNorm(32, num_filters)
self.down_sampling = down_sampling
if down_sampling:
self.down_sampling_stride = down_sampling_stride
else:
self.down_sampling_stride = [1, 1, 1]
if num_filters != input_filters or down_sampling:
self.conv3 = CausalConvChannelLast(
input_filters,
num_filters,
kernel_size=[1, 1, 1],
stride=self.down_sampling_stride,
)
self.norm3 = nn.GroupNorm(32, num_filters)
def forward(self, x, is_init=False) -> torch.Tensor:
x = x.permute(0, 2, 3, 4, 1).contiguous()
residual = x
h = self.conv1(x, is_init)
h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True)
h = self.conv2(h, is_init)
h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True)
if self.down_sampling or self.num_filters != self.input_filters:
x = self.conv3(x, is_init)
x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True)
h.add_(x)
h = self.act_(h)
if residual is not None:
h.add_(residual)
h = h.permute(0, 4, 1, 2, 3)
return h
class Upsample3D(nn.Module):
def __init__(self, in_channels, scale_factor=2) -> None:
super().__init__()
self.scale_factor = scale_factor
self.conv3d = Res3DBlockUpsample(
input_filters=in_channels,
num_filters=in_channels,
down_sampling_stride=(1, 1, 1),
down_sampling=False,
)
def forward(self, x, is_init=True, is_split=True) -> torch.Tensor:
b, c, t, h, w = x.shape
# x = x.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3).to(memory_format=torch.channels_last_3d)
if is_split:
split_size = c // 8
x_slices = torch.split(x, split_size, dim=1)
x = [
nn.functional.interpolate(x, scale_factor=self.scale_factor)
for x in x_slices
]
x = torch.cat(x, dim=1)
else:
x = nn.functional.interpolate(x, scale_factor=self.scale_factor)
x = self.conv3d(x, is_init)
return x
class VideoDecoder(nn.Module):
def __init__(
self,
ch=128,
z_channels=16,
out_channels=3,
ch_mult=(1, 2, 4, 4),
num_res_blocks=2,
temporal_up_layers=(2, 3),
temporal_downsample=4,
resamp_with_conv=True,
version=1,
) -> None:
super().__init__()
temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.temporal_downsample = temporal_downsample
block_in = ch * ch_mult[self.num_resolutions - 1]
self.version = version
if version == 2:
channels = 4 * z_channels * 2**3
self.conv_in = CausalConv(z_channels, channels, kernel_size=3)
self.shortcut_in = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(
z_channels, channels, 1
)
self.conv_unpatchify = ConvPixelShuffleUpSampleLayer3D(
channels, block_in, kernel_size=3, factor=2
)
self.shortcut_unpathify = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(
channels, block_in, 2
)
else:
self.conv_in = CausalConv(z_channels, block_in, kernel_size=3)
# middle
self.mid = nn.Module()
self.mid.block_1 = Resnet3DBlock(
in_channels=block_in, out_channels=block_in, temb_channels=temb_ch
)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = Resnet3DBlock(
in_channels=block_in, out_channels=block_in, temb_channels=temb_ch
)
# upsampling
self.up_id = len(temporal_up_layers)
self.video_frame_num = 1
self.cur_video_frame_num = self.video_frame_num // 2**self.up_id + 1
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
Resnet3DBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=temb_ch,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level in temporal_up_layers:
up.upsample = Upsample3D(block_in)
self.cur_video_frame_num = self.cur_video_frame_num * 2
else:
up.upsample = Upsample2D(block_in, resamp_with_conv)
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in)
self.conv_out = CausalConvAfterNorm(block_in, out_channels, kernel_size=3)
@torch.inference_mode()
def forward(self, z, is_init=True) -> torch.Tensor:
z = rearrange(z, "b t c h w -> b c t h w")
h = self.conv_in(z, is_init=is_init)
if self.version == 2:
shortcut = self.shortcut_in(z, is_init=is_init)
h = h.add_(shortcut)
shortcut = self.shortcut_unpathify(h, is_init=is_init)
h = self.conv_unpatchify(h, is_init=is_init)
h = h.add_(shortcut)
temb = None
h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
h = self.mid.block_1(h, temb, is_init=is_init)
h = self.mid.attn_1(h)
h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
h = self.mid.block_2(h, temb, is_init=is_init)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3)
h = self.up[i_level].block[i_block](h, temb, is_init=is_init)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
if isinstance(self.up[i_level].upsample, Upsample2D):
B = h.size(0)
h = h.permute(0, 2, 3, 4, 1).flatten(0, 1)
h = self.up[i_level].upsample(h)
h = h.unflatten(0, (B, -1)).permute(0, 4, 1, 2, 3)
else:
h = self.up[i_level].upsample(h, is_init=is_init)
# end
h = h.permute(0, 2, 3, 4, 1) # b c l h w -> b l h w c
h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2)
h = self.conv_out(h)
h = h.permute(0, 4, 1, 2, 3)
if is_init:
h = h[:, :, (self.temporal_downsample - 1) :]
return h
def rms_norm(input, normalized_shape, eps=1e-6) -> torch.Tensor:
dtype = input.dtype
input = input.to(torch.float32)
variance = (
input.pow(2)
.flatten(-len(normalized_shape))
.mean(-1)[(...,) + (None,) * len(normalized_shape)]
)
input = input * torch.rsqrt(variance + eps)
return input.to(dtype)
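# Equivalence sketch (illustrative only): rms_norm divides by the root mean square
# computed over `normalized_shape` (the trailing dimensions), i.e.
# x * rsqrt(mean(x**2) + eps), with statistics computed in float32.
def _example_rms_norm_check():
    x = torch.randn(2, 3, 4, 4)
    out = rms_norm(x, normalized_shape=(3, 4, 4), eps=1e-6)
    ref = x * torch.rsqrt(x.float().pow(2).mean(dim=(1, 2, 3), keepdim=True) + 1e-6)
    assert torch.allclose(out, ref.to(x.dtype), atol=1e-5)
    return out.shape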
class DiagonalGaussianDistribution:
def __init__(
self,
parameters,
deterministic=False,
rms_norm_mean=False,
only_return_mean=False,
) -> None:
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=-3) # N,[X],C,H,W
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
self.deterministic = deterministic
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
if rms_norm_mean:
self.mean = rms_norm(self.mean, self.mean.size()[1:])
self.only_return_mean = only_return_mean
def sample(self, generator=None) -> torch.Tensor:
# make sure sample is on the same device
# as the parameters and has same dtype
sample = torch.randn(
self.mean.shape, generator=generator, device=self.parameters.device
)
sample = sample.to(dtype=self.parameters.dtype)
x = self.mean + self.std * sample
if self.only_return_mean:
return self.mean
else:
return x
class AutoencoderKLStepvideo(nn.Module, ParallelTiledVAE):
def __init__(
self,
config: StepVideoVAEConfig,
) -> None:
nn.Module.__init__(self)
ParallelTiledVAE.__init__(self, config)
self.frame_len = config.frame_len
if config.version == 2:
self.latent_len = 3
base_group_norm.spatial = True # type: ignore[attr-defined]
else:
self.latent_len = 5
base_group_norm.spatial = False # type: ignore[attr-defined]
self.encoder = VideoEncoder(
in_channels=config.in_channels,
z_channels=config.z_channels,
num_res_blocks=config.num_res_blocks,
version=config.version,
)
self.decoder = VideoDecoder(
z_channels=config.z_channels,
out_channels=config.out_channels,
num_res_blocks=config.num_res_blocks,
version=config.version,
)
self.world_size = config.world_size
# self.is_init = True
def load_state_dict(self, state_dict, strict=True):
remapped = {}
for key, value in state_dict.items():
if key.startswith("decoder.conv_out."):
# move “decoder.conv_out.weight” → “decoder.conv_out.conv.weight”
suffix = key[len("decoder.conv_out.") :]
remapped[f"decoder.conv_out.conv.{suffix}"] = value
else:
remapped[key] = value
        return super().load_state_dict(remapped, strict=strict)
def _encode(self, x, is_init_image=True) -> torch.Tensor:
# b, len, c, h, w = x.size()
        b, c, t, h, w = x.size()
        # x = rearrange(x, 'b l c h w -> b c l h w').contiguous()
        z = self.encoder(x, t, True)  # downsampling factors: [1, 4, 8, 16, 16]
return z
@torch.inference_mode()
def encode(self, x):
# b (nc cf) c h w -> (b nc) cf c h w -> encode -> (b nc) cf c h w -> b (nc cf) c h w
chunks = list(x.split(self.frame_len, dim=1))
for i in range(len(chunks)):
chunks[i] = self._encode(chunks[i], True)
z = torch.cat(chunks, dim=1)
posterior = DiagonalGaussianDistribution(z)
return posterior.sample()
def _decode(self, z) -> torch.Tensor:
chunks = list(z.split(self.latent_len, dim=2))
for i in range(len(chunks)):
chunks[i] = chunks[i].permute(0, 2, 1, 3, 4)
chunks[i] = chunks[i].to(next(self.decoder.parameters()).dtype)
chunks[i] = self.decoder(chunks[i], is_init=True)
x = torch.cat(chunks, dim=2)
return x
def decode(self, z) -> torch.Tensor:
num_frames = z.size(2)
dec = ParallelTiledVAE.decode(self, z).permute(0, 2, 1, 3, 4)
dec = self.mix(dec).permute(0, 2, 1, 3, 4)
        # every 3 latent frames correspond to 17 decoded frames; trim to that length
        num_sample_frames = num_frames // 3 * 17
return dec[:, :, :num_sample_frames]
def mix(self, x) -> torch.Tensor:
        # Cross-fade the last frame of each frame_len chunk with the first frame of
        # the following chunk to smooth seams between independently decoded chunks.
        remain_scale = 0.6
mix_scale = 1.0 - remain_scale
front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len)
back = slice(self.frame_len, x.size(1), self.frame_len)
x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale
x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale
return x
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
generator: torch.Generator | None = None,
) -> torch.Tensor:
"""
Args:
sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
return dec
EntryClass = AutoencoderKLStepvideo
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextvars
from contextlib import contextmanager
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig
from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
from sglang.multimodal_gen.runtime.models.vaes.common import (
DiagonalGaussianDistribution,
ParallelTiledVAE,
)
from sglang.multimodal_gen.runtime.platforms import current_platform
CACHE_T = 2
is_first_frame = contextvars.ContextVar("is_first_frame", default=False)
feat_cache = contextvars.ContextVar("feat_cache", default=None)
feat_idx = contextvars.ContextVar("feat_idx", default=0)
first_chunk = contextvars.ContextVar("first_chunk", default=None)
@contextmanager
def forward_context(
first_frame_arg=False, feat_cache_arg=None, feat_idx_arg=None, first_chunk_arg=None
):
is_first_frame_token = is_first_frame.set(first_frame_arg)
feat_cache_token = feat_cache.set(feat_cache_arg)
feat_idx_token = feat_idx.set(feat_idx_arg)
first_chunk_token = first_chunk.set(first_chunk_arg)
try:
yield
finally:
is_first_frame.reset(is_first_frame_token)
feat_cache.reset(feat_cache_token)
feat_idx.reset(feat_idx_token)
first_chunk.reset(first_chunk_token)
class AvgDown3D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
pad = (0, 0, 0, 0, pad_t, 0)
x = F.pad(x, pad)
B, C, T, H, W = x.shape
x = x.view(
B,
C,
T // self.factor_t,
self.factor_t,
H // self.factor_s,
self.factor_s,
W // self.factor_s,
self.factor_s,
)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
x = x.view(
B,
C * self.factor,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.view(
B,
self.out_channels,
self.group_size,
T // self.factor_t,
H // self.factor_s,
W // self.factor_s,
)
x = x.mean(dim=2)
return x
class DupUp3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor_t,
factor_s=1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = self.factor_t * self.factor_s * self.factor_s
assert out_channels * self.factor % in_channels == 0
self.repeats = out_channels * self.factor // in_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = x.view(
x.size(0),
self.out_channels,
self.factor_t,
self.factor_s,
self.factor_s,
x.size(2),
x.size(3),
x.size(4),
)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
x = x.view(
x.size(0),
self.out_channels,
x.size(2) * self.factor_t,
x.size(4) * self.factor_s,
x.size(6) * self.factor_s,
)
_first_chunk = first_chunk.get()
if _first_chunk:
x = x[:, :, self.factor_t - 1 :, :, :]
return x
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
caching for efficient inference.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int, int],
stride: int | tuple[int, int, int] = 1,
padding: int | tuple[int, int, int] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
self.padding: tuple[int, int, int]
# Set up causal padding
self._padding: tuple[int, ...] = (
self.padding[2],
self.padding[2],
self.padding[1],
self.padding[1],
2 * self.padding[0],
0,
)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
x = (
x.to(self.weight.dtype) if current_platform.is_mps() else x
) # casting needed for mps since amp isn't supported
return super().forward(x)
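# Usage sketch (illustrative only; shapes are hypothetical): when decoding chunk by chunk, the
# last CACHE_T input frames of the previous chunk are passed as `cache_x` so the causal padding
# on the time axis is filled with real history instead of zeros.
def _example_causal_conv_with_cache() -> None:
    conv = WanCausalConv3d(in_channels=8, out_channels=8, kernel_size=3, padding=1)
    prev_chunk = torch.randn(1, 8, 4, 16, 16)  # (b, c, t, h, w)
    next_chunk = torch.randn(1, 8, 4, 16, 16)
    _ = conv(prev_chunk)  # first chunk: zero-padded history
    cache = prev_chunk[:, :, -CACHE_T:, :, :].clone()  # keep the last CACHE_T frames
    _ = conv(next_chunk, cache_x=cache)  # later chunk reuses the cached history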
class WanRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(
self,
dim: int,
channel_first: bool = True,
images: bool = True,
bias: bool = False,
) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return (
F.normalize(x, dim=(1 if self.channel_first else -1))
* self.scale
* self.gamma
+ self.bias
)
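# Quick shape check (illustrative, not part of the original file): with images=False the norm
# expects channel-first video tensors (b, c, t, h, w); each position is L2-normalized over the
# channel axis, then rescaled by sqrt(dim) and the learned per-channel gamma.
def _example_rms_norm_shapes() -> None:
    norm = WanRMS_norm(dim=8, images=False)
    x = torch.randn(2, 8, 4, 16, 16)
    assert norm(x).shape == x.shape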
class WanUpsample(nn.Upsample):
r"""
Perform upsampling while ensuring the output tensor has the same data type as the input.
Args:
x (torch.Tensor): Input tensor to be upsampled.
Returns:
torch.Tensor: Upsampled tensor with the same data type as the input.
"""
def forward(self, x):
return super().forward(x.float()).type_as(x)
class WanResample(nn.Module):
r"""
A custom resampling module for 2D and 3D data.
Args:
dim (int): The number of input/output channels.
mode (str): The resampling mode. Must be one of:
- 'none': No resampling (identity operation).
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
- 'downsample2d': 2D downsampling with zero-padding and convolution.
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# default to dim //2
if upsample_out_dim is None:
upsample_out_dim = dim // 2
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
elif mode == "downsample3d":
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
)
self.time_conv = WanCausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else:
self.resample = nn.Identity()
def forward(self, x):
b, c, t, h, w = x.size()
first_frame = is_first_frame.get()
if first_frame:
assert t == 1
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if self.mode == "upsample3d":
if _feat_cache is not None:
idx = _feat_idx
if _feat_cache[idx] is None:
_feat_cache[idx] = "Rep"
_feat_idx += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if (
cache_x.shape[2] < 2
and _feat_cache[idx] is not None
and _feat_cache[idx] != "Rep"
):
                        # cache the last frame of the last two chunks
cache_x = torch.cat(
[
_feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
if (
cache_x.shape[2] < 2
and _feat_cache[idx] is not None
and _feat_cache[idx] == "Rep"
):
cache_x = torch.cat(
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
dim=2,
)
if _feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, _feat_cache[idx])
_feat_cache[idx] = cache_x
_feat_idx += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
elif not first_frame and hasattr(self, "time_conv"):
x = self.time_conv(x)
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if self.mode == "downsample3d":
if _feat_cache is not None:
idx = _feat_idx
if _feat_cache[idx] is None:
_feat_cache[idx] = x.clone()
_feat_idx += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([_feat_cache[idx][:, :, -1:, :, :], x], 2)
)
_feat_cache[idx] = cache_x
_feat_idx += 1
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
elif not first_frame and hasattr(self, "time_conv"):
x = self.time_conv(x)
return x
class WanResidualBlock(nn.Module):
r"""
A custom residual block module.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
in_dim: int,
out_dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
) -> None:
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.nonlinearity = get_act_fn(non_linearity)
# layers
self.norm1 = WanRMS_norm(in_dim, images=False)
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = WanRMS_norm(out_dim, images=False)
self.dropout = nn.Dropout(dropout)
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
self.conv_shortcut = (
WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
)
def forward(self, x):
# Apply shortcut connection
h = self.conv_shortcut(x)
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if _feat_cache is not None:
idx = _feat_idx
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
cache_x = torch.cat(
[
_feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv1(x, _feat_cache[idx])
_feat_cache[idx] = cache_x
_feat_idx += 1
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
else:
x = self.conv1(x)
# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)
# Dropout
x = self.dropout(x)
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if _feat_cache is not None:
idx = _feat_idx
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
cache_x = torch.cat(
[
_feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv2(x, _feat_cache[idx])
_feat_cache[idx] = cache_x
_feat_idx += 1
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
else:
x = self.conv2(x)
# Add residual connection
return x + h
class WanAttentionBlock(nn.Module):
r"""
Causal self-attention with a single head.
Args:
dim (int): The number of channels in the input tensor.
"""
def __init__(self, dim) -> None:
super().__init__()
self.dim = dim
# layers
self.norm = WanRMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def forward(self, x):
identity = x
batch_size, channels, time, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x)
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
qkv = qkv.permute(0, 1, 3, 2).contiguous()
q, k, v = qkv.chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(q, k, v)
x = (
x.squeeze(1)
.permute(0, 2, 1)
.reshape(batch_size * time, channels, height, width)
)
# output projection
x = self.proj(x)
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
x = x.view(batch_size, time, channels, height, width)
x = x.permute(0, 2, 1, 3, 4)
return x + identity
class WanMidBlock(nn.Module):
"""
Middle block for WanVAE encoder and decoder.
Args:
dim (int): Number of input/output channels.
dropout (float): Dropout rate.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
num_layers: int = 1,
):
super().__init__()
self.dim = dim
# Create the components
resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
attentions = []
for _ in range(num_layers):
attentions.append(WanAttentionBlock(dim))
resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, x):
# First residual block
x = self.resnets[0](x)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True):
if attn is not None:
x = attn(x)
x = resnet(x)
return x
class WanResidualDownBlock(nn.Module):
def __init__(
self,
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=False,
down_flag=False,
):
super().__init__()
# Shortcut path with downsample
self.avg_shortcut = AvgDown3D(
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path with residual blocks and downsample
resnets = []
for _ in range(num_res_blocks):
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
in_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add the final downsample block
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
self.downsampler = WanResample(out_dim, mode=mode)
else:
self.downsampler = None
def forward(self, x):
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x)
if self.downsampler is not None:
x = self.downsampler(x)
return x + self.avg_shortcut(x_copy)
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_downsample (list of bool): Whether to downsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
in_channels: int = 3,
dim=128,
z_dim=4,
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
attn_scales=(),
temperal_downsample=(True, True, False),
dropout=0.0,
non_linearity: str = "silu",
is_residual: bool = False, # wan 2.2 vae use a residual downblock
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
dim_mult = list(dim_mult)
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = list(attn_scales)
self.temperal_downsample = list(temperal_downsample)
self.nonlinearity = get_act_fn(non_linearity)
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)):
# residual (+attention) blocks
if is_residual:
self.down_blocks.append(
WanResidualDownBlock(
in_dim,
out_dim,
dropout,
num_res_blocks,
temperal_downsample=(
temperal_downsample[i] if i != len(dim_mult) - 1 else False
),
down_flag=i != len(dim_mult) - 1,
)
)
else:
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x):
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if _feat_cache is not None:
idx = _feat_idx
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
cache_x = torch.cat(
[
_feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_in(x, _feat_cache[idx])
_feat_cache[idx] = cache_x
_feat_idx += 1
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
else:
x = self.conv_in(x)
## downsamples
for layer in self.down_blocks:
x = layer(x)
## middle
x = self.mid_block(x)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if _feat_cache is not None:
idx = _feat_idx
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
cache_x = torch.cat(
[
_feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_out(x, _feat_cache[idx])
_feat_cache[idx] = cache_x
_feat_idx += 1
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
else:
x = self.conv_out(x)
return x
# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
class WanResidualUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
temperal_upsample (bool): Whether to upsample on temporal dimension
up_flag (bool): Whether to upsample or not
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
temperal_upsample: bool = False,
up_flag: bool = False,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2,
)
else:
self.avg_shortcut = None
# create residual blocks
resnets = []
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(
WanResidualBlock(current_dim, out_dim, dropout, non_linearity)
)
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
if up_flag:
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
self.upsampler = WanResample(
out_dim, mode=upsample_mode, upsample_out_dim=out_dim
)
else:
self.upsampler = None
self.gradient_checkpointing = False
def forward(self, x):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
if self.avg_shortcut is not None:
x_copy = x.clone()
for resnet in self.resnets:
x = resnet(x)
if self.upsampler is not None:
x = self.upsampler(x)
if self.avg_shortcut is not None:
x = x + self.avg_shortcut(x_copy)
return x
class WanUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
upsample_mode: str | None = None,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Create layers list
resnets = []
# Add residual blocks and attention if needed
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(
WanResidualBlock(current_dim, out_dim, dropout, non_linearity)
)
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
self.upsamplers = None
if upsample_mode is not None:
self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
self.gradient_checkpointing = False
def forward(self, x):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
for resnet in self.resnets:
x = resnet(x)
if self.upsamplers is not None:
x = self.upsamplers[0](x)
return x
class WanDecoder3d(nn.Module):
r"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
attn_scales=(),
temperal_upsample=(False, True, True),
dropout=0.0,
non_linearity: str = "silu",
out_channels: int = 3,
is_residual: bool = False,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
dim_mult = list(dim_mult)
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = list(attn_scales)
self.temperal_upsample = list(temperal_upsample)
self.nonlinearity = get_act_fn(non_linearity)
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
# upsample blocks
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)):
# residual (+attention) blocks
if i > 0 and not is_residual:
# wan vae 2.1
in_dim = in_dim // 2
# determine if we need upsampling
up_flag = i != len(dim_mult) - 1
# determine upsampling mode, if not upsampling, set to None
upsample_mode = None
if up_flag and temperal_upsample[i]:
upsample_mode = "upsample3d"
elif up_flag:
upsample_mode = "upsample2d"
# Create and add the upsampling block
if is_residual:
up_block = WanResidualUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
temperal_upsample=temperal_upsample[i] if up_flag else False,
up_flag=up_flag,
non_linearity=non_linearity,
)
else:
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x):
## conv1
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if _feat_cache is not None:
idx = _feat_idx
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
cache_x = torch.cat(
[
_feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_in(x, _feat_cache[idx])
_feat_cache[idx] = cache_x
_feat_idx += 1
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
else:
x = self.conv_in(x)
## middle
x = self.mid_block(x)
## upsamples
for up_block in self.up_blocks:
x = up_block(x)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
_feat_cache = feat_cache.get()
_feat_idx = feat_idx.get()
if _feat_cache is not None:
idx = _feat_idx
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and _feat_cache[idx] is not None:
                # cache the last frame of the last two chunks
cache_x = torch.cat(
[
_feat_cache[idx][:, :, -1, :, :]
.unsqueeze(2)
.to(cache_x.device),
cache_x,
],
dim=2,
)
x = self.conv_out(x, _feat_cache[idx])
_feat_cache[idx] = cache_x
_feat_idx += 1
feat_cache.set(_feat_cache)
feat_idx.set(_feat_idx)
else:
x = self.conv_out(x)
return x
def patchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b c f (h q) (w r) -> b (c r q) f h w",
q=patch_size,
r=patch_size,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size):
if patch_size == 1:
return x
if x.dim() == 4:
x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
elif x.dim() == 5:
x = rearrange(
x,
"b (c r q) f h w -> b c f (h q) (w r)",
q=patch_size,
r=patch_size,
)
return x
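# Round-trip sketch (illustrative, hypothetical shapes): patchify folds spatial patches into
# the channel axis and unpatchify inverts it exactly, so the pair is lossless.
def _example_patchify_roundtrip() -> None:
    x = torch.randn(1, 3, 5, 32, 32)  # (b, c, f, h, w)
    p = patchify(x, patch_size=2)     # -> (1, 12, 5, 16, 16)
    assert p.shape == (1, 12, 5, 16, 16)
    assert torch.equal(unpatchify(p, patch_size=2), x)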
class AutoencoderKLWan(nn.Module, ParallelTiledVAE):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
"""
_supports_gradient_checkpointing = False
def __init__(
self,
config: WanVAEConfig,
) -> None:
nn.Module.__init__(self)
ParallelTiledVAE.__init__(self, config)
self.z_dim = config.z_dim
self.temperal_downsample = list(config.temperal_downsample)
self.temperal_upsample = list(config.temperal_downsample)[::-1]
if config.decoder_base_dim is None:
decoder_base_dim = config.base_dim
else:
decoder_base_dim = config.decoder_base_dim
self.latents_mean = list(config.latents_mean)
self.latents_std = list(config.latents_std)
self.shift_factor = config.shift_factor
if config.load_encoder:
self.encoder = WanEncoder3d(
in_channels=config.in_channels,
dim=config.base_dim,
z_dim=self.z_dim * 2,
dim_mult=config.dim_mult,
num_res_blocks=config.num_res_blocks,
attn_scales=config.attn_scales,
temperal_downsample=self.temperal_downsample,
dropout=config.dropout,
is_residual=config.is_residual,
)
self.quant_conv = WanCausalConv3d(self.z_dim * 2, self.z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(self.z_dim, self.z_dim, 1)
if config.load_decoder:
self.decoder = WanDecoder3d(
dim=decoder_base_dim,
z_dim=self.z_dim,
dim_mult=config.dim_mult,
num_res_blocks=config.num_res_blocks,
attn_scales=config.attn_scales,
temperal_upsample=self.temperal_upsample,
dropout=config.dropout,
out_channels=config.out_channels,
is_residual=config.is_residual,
)
self.use_feature_cache = config.use_feature_cache
def clear_cache(self) -> None:
def _count_conv3d(model) -> int:
count = 0
for m in model.modules():
if isinstance(m, WanCausalConv3d):
count += 1
return count
if self.config.load_decoder:
self._conv_num = _count_conv3d(self.decoder)
self._conv_idx = 0
self._feat_map = [None] * self._conv_num
# cache encode
if self.config.load_encoder:
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_idx = 0
self._enc_feat_map = [None] * self._enc_conv_num
def encode(self, x: torch.Tensor) -> torch.Tensor:
if self.use_feature_cache:
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
with forward_context(
feat_cache_arg=self._enc_feat_map, feat_idx_arg=self._enc_conv_idx
):
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
feat_idx.set(0)
if i == 0:
out = self.encoder(x[:, :, :1, :, :])
else:
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :])
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
enc = torch.cat([mu, logvar], dim=1)
enc = DiagonalGaussianDistribution(enc)
self.clear_cache()
else:
for block in self.encoder.down_blocks:
if isinstance(block, WanResample) and block.mode == "downsample3d":
_padding = list(block.time_conv._padding)
_padding[4] = 2
block.time_conv._padding = tuple(_padding)
enc = ParallelTiledVAE.encode(self, x)
return enc
def _encode(self, x: torch.Tensor, first_frame=False) -> torch.Tensor:
with forward_context(first_frame_arg=first_frame):
out = self.encoder(x)
enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
enc = torch.cat([mu, logvar], dim=1)
return enc
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
first_frame = x[:, :, 0, :, :].unsqueeze(2)
first_frame = self._encode(first_frame, first_frame=True)
enc = ParallelTiledVAE.tiled_encode(self, x)
enc = enc[:, :, 1:]
enc = torch.cat([first_frame, enc], dim=2)
return enc
def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
first_frame = x[:, :, 0, :, :].unsqueeze(2)
first_frame = self._encode(first_frame, first_frame=True)
enc = ParallelTiledVAE.spatial_tiled_encode(self, x)
enc = enc[:, :, 1:]
enc = torch.cat([first_frame, enc], dim=2)
return enc
def decode(self, z: torch.Tensor) -> torch.Tensor:
if self.use_feature_cache:
self.clear_cache()
iter_ = z.shape[2]
x = self.post_quant_conv(z)
with forward_context(
feat_cache_arg=self._feat_map, feat_idx_arg=self._conv_idx
):
for i in range(iter_):
feat_idx.set(0)
if i == 0:
first_chunk.set(True)
out = self.decoder(x[:, :, i : i + 1, :, :])
else:
first_chunk.set(False)
out_ = self.decoder(x[:, :, i : i + 1, :, :])
out = torch.cat([out, out_], 2)
if self.config.patch_size is not None:
out = unpatchify(out, patch_size=self.config.patch_size)
out = out.float()
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
else:
out = ParallelTiledVAE.decode(self, z)
return out
def _decode(self, z: torch.Tensor, first_frame=False) -> torch.Tensor:
x = self.post_quant_conv(z)
with forward_context(first_frame_arg=first_frame):
out = self.decoder(x)
out = torch.clamp(out, min=-1.0, max=1.0)
return out
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
self.blend_num_frames *= 2
dec = ParallelTiledVAE.tiled_decode(self, z)
start_frame_idx = self.temporal_compression_ratio - 1
dec = dec[:, :, start_frame_idx:]
return dec
def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
dec = ParallelTiledVAE.spatial_tiled_decode(self, z)
start_frame_idx = self.temporal_compression_ratio - 1
dec = dec[:, :, start_frame_idx:]
return dec
def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor:
self.blend_num_frames *= 2
dec = ParallelTiledVAE.parallel_tiled_decode(self, z)
start_frame_idx = self.temporal_compression_ratio - 1
dec = dec[:, :, start_frame_idx:]
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
generator: torch.Generator | None = None,
) -> torch.Tensor:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
return dec
EntryClass = AutoencoderKLWan
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
import os
import tempfile
from collections.abc import Callable
from urllib.parse import unquote, urlparse
import imageio
import numpy as np
import PIL.Image
import PIL.ImageOps
import requests
import torch
from packaging import version
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray:
r"""
Convert a PIL image or a list of PIL images to NumPy arrays.
Args:
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
The PIL image or list of images to convert to NumPy format.
Returns:
`np.ndarray`:
A NumPy array representation of the images.
"""
if not isinstance(images, list):
images = [images]
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
images_arr: np.ndarray = np.stack(images, axis=0)
return images_arr
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
r"""
Convert a NumPy image to a PyTorch tensor.
Args:
images (`np.ndarray`):
The NumPy image array to convert to PyTorch format.
Returns:
`torch.Tensor`:
A PyTorch tensor representation of the images.
"""
if images.ndim == 3:
images = images[..., None]
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
return images
def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
r"""
Normalize an image array to [-1,1].
Args:
images (`np.ndarray` or `torch.Tensor`):
The image array to normalize.
Returns:
`np.ndarray` or `torch.Tensor`:
The normalized image array.
"""
return 2.0 * images - 1.0
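# Chained usage sketch (illustrative): PIL -> float32 NumPy in [0, 1] -> torch (N, C, H, W)
# -> values in [-1, 1], matching the [-1, 1] range produced by the VAE decoders above.
def _example_image_to_tensor() -> None:
    img = PIL.Image.new("RGB", (8, 8))  # hypothetical 8x8 black image
    t = normalize(numpy_to_pt(pil_to_numpy(img)))
    assert t.shape == (1, 3, 8, 8)
    assert float(t.min()) >= -1.0 and float(t.max()) <= 1.0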
# adapted from diffusers.utils import load_image
def load_image(
image: str | PIL.Image.Image,
convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None,
) -> PIL.Image.Image:
"""
Loads `image` to a PIL Image.
Args:
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be
            converted to "RGB".
Returns:
`PIL.Image.Image`:
A PIL Image.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
if convert_method is not None:
image = convert_method(image)
else:
image = image.convert("RGB")
return image
# adapted from diffusers.utils import load_video
def load_video(
video: str,
convert_method: (
Callable[[list[PIL.Image.Image]], list[PIL.Image.Image]] | None
) = None,
) -> list[PIL.Image.Image]:
"""
Loads `video` to a list of PIL Image.
Args:
video (`str`):
A URL or Path to a video to convert to a list of PIL Image format.
convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
A conversion method to apply to the video after loading it. When set to `None` the images will be converted
to "RGB".
Returns:
`List[PIL.Image.Image]`:
The video as a list of PIL images.
"""
is_url = video.startswith("http://") or video.startswith("https://")
is_file = os.path.isfile(video)
was_tempfile_created = False
if not (is_url or is_file):
raise ValueError(
f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
)
if is_url:
response = requests.get(video, stream=True)
if response.status_code != 200:
raise ValueError(
f"Failed to download video. Status code: {response.status_code}"
)
parsed_url = urlparse(video)
file_name = os.path.basename(unquote(parsed_url.path))
suffix = os.path.splitext(file_name)[1] or ".mp4"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            video_path = temp_file.name
            was_tempfile_created = True
            video_data = response.iter_content(chunk_size=8192)
            for chunk in video_data:
                temp_file.write(chunk)
        video = video_path
pil_images = []
if video.endswith(".gif"):
gif = PIL.Image.open(video)
try:
while True:
pil_images.append(gif.copy())
gif.seek(gif.tell() + 1)
except EOFError:
pass
else:
try:
imageio.plugins.ffmpeg.get_exe()
except AttributeError:
            raise AttributeError(
                "Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg`"
            ) from None
with imageio.get_reader(video) as reader:
# Read all frames
for frame in reader:
pil_images.append(PIL.Image.fromarray(frame))
if was_tempfile_created:
os.remove(video_path)
if convert_method is not None:
pil_images = convert_method(pil_images)
return pil_images
def get_default_height_width(
image: PIL.Image.Image | np.ndarray | torch.Tensor,
vae_scale_factor: int,
height: int | None = None,
width: int | None = None,
) -> tuple[int, int]:
r"""
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
Args:
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
tensor, it should have shape `[batch, channels, height, width]`.
height (`Optional[int]`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`Optional[int]`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
Returns:
`Tuple[int, int]`:
A tuple containing the height and width, both resized to the nearest integer multiple of
`vae_scale_factor`.
"""
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
width = image.shape[2]
width, height = (
x - x % vae_scale_factor for x in (width, height)
) # resize to integer multiple of vae_scale_factor
return height, width
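# Quick illustration (hypothetical size): a 769x513 image with vae_scale_factor=8 is snapped
# down to the nearest multiples of 8 so the latent grid has an integer extent.
def _example_default_height_width() -> None:
    img = PIL.Image.new("RGB", (769, 513))  # PIL sizes are (width, height)
    assert get_default_height_width(img, vae_scale_factor=8) == (512, 768)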
def resize(
image: PIL.Image.Image | np.ndarray | torch.Tensor,
height: int,
width: int,
resize_mode: str = "default", # "default", "fill", "crop"
resample: str = "lanczos",
) -> PIL.Image.Image | np.ndarray | torch.Tensor:
"""
Resize image.
Args:
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, can be a PIL image, numpy array or pytorch tensor.
height (`int`):
The height to resize to.
width (`int`):
The width to resize to.
        resize_mode (`str`, *optional*, defaults to `default`):
            The resize mode to use; can be one of `default`, `fill`, or `crop`. If `default`, the image is
            resized to fit within the specified width and height, possibly without maintaining the original
            aspect ratio. If `fill`, the image is resized to fit within the specified width and height while
            maintaining the aspect ratio, then centered within the dimensions, with the empty space filled
            with data from the image. If `crop`, the image is resized to fit within the specified width and
            height while maintaining the aspect ratio, then centered within the dimensions, cropping the
            excess. Note that the `fill` and `crop` modes are only supported for PIL image input.
        resample (`str`, *optional*, defaults to `lanczos`):
            The PIL resampling filter to use, one of the keys of `PIL_INTERPOLATION`.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The resized image.
"""
if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
raise ValueError(
f"Only PIL image input is supported for resize_mode {resize_mode}"
)
assert isinstance(image, PIL.Image.Image)
if resize_mode == "default":
image = image.resize((width, height), resample=PIL_INTERPOLATION[resample])
else:
raise ValueError(f"resize_mode {resize_mode} is not supported")
return image
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# Adding a New Custom Pipeline
Please see documentation [here](https://hao-ai-lab.github.io/sgl-diffusion/contributing/add_pipeline.html)
# PipelineStages
Basic components of a pipeline, which can be reused by custom pipelines for different models.
The stages form a partial order.
# PipelineExecutors
Runs the stages of a pipeline in various ways (see the sketch below). Supported modes:
1. sync
2. async
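A minimal sketch of how an executor walks the stages (it mirrors the `SyncExecutor` defined elsewhere in this PR, where each stage is a callable taking a `Req` batch and `ServerArgs`):

```python
def run_sync(stages, batch, server_args):
    # Each stage is a callable: stage(batch, server_args) -> batch
    for stage in stages:
        batch = stage(batch, server_args)
    return batch
```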
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Diffusion pipelines for sglang.multimodal_gen.
This package contains diffusion pipelines for generating videos and images.
"""
from typing import cast
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.pipeline_registry import (
PipelineType,
get_pipeline_registry,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
maybe_download_model,
verify_model_config_and_directory,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class PipelineWithLoRA(LoRAPipeline, ComposedPipelineBase):
"""Type for a pipeline that has both ComposedPipelineBase and LoRAPipeline functionality."""
pass
def build_pipeline(
server_args: ServerArgs,
pipeline_type: PipelineType | str = PipelineType.BASIC,
) -> PipelineWithLoRA:
"""
Only works with valid hf diffusers configs. (model_index.json)
We want to build a pipeline based on the inference args mode_path:
1. download the model from the hub if it's not already downloaded
2. verify the model config and directory
3. based on the config, determine the pipeline class
"""
# Get pipeline type
model_path = server_args.model_path
model_path = maybe_download_model(model_path)
# server_args.downloaded_model_path = model_path
logger.info("Model path: %s", model_path)
config = verify_model_config_and_directory(model_path)
pipeline_name = config.get("_class_name")
if pipeline_name is None:
raise ValueError(
"Model config does not contain a _class_name attribute. "
"Only diffusers format is supported."
)
# Get the appropriate pipeline registry based on pipeline_type
logger.info(
"Building pipeline of type: %s",
(
pipeline_type.value
if isinstance(pipeline_type, PipelineType)
else pipeline_type
),
)
pipeline_registry = get_pipeline_registry(pipeline_type)
if isinstance(pipeline_type, str):
pipeline_type = PipelineType.from_string(pipeline_type)
pipeline_cls = pipeline_registry.resolve_pipeline_cls(
pipeline_name, pipeline_type, server_args.workload_type
)
# instantiate the pipelines
pipeline = pipeline_cls(model_path, server_args)
logger.info("Pipelines instantiated")
return cast(PipelineWithLoRA, pipeline)
__all__ = [
"build_pipeline",
"ComposedPipelineBase",
"Req",
"LoRAPipeline",
]
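# Usage sketch (illustrative): builds a pipeline from a diffusers-format directory containing
# model_index.json. The path argument is hypothetical, and this assumes ServerArgs.from_kwargs
# can be called with model_path alone, mirroring its use in ComposedPipelineBase.from_pretrained.
def _example_build_pipeline(model_path: str) -> PipelineWithLoRA:
    server_args = ServerArgs.from_kwargs(model_path=model_path)
    return build_pipeline(server_args, pipeline_type=PipelineType.BASIC)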
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Base class for composed pipelines.
This module defines the base class for pipelines that are composed of multiple stages.
"""
import argparse
import os
from abc import ABC, abstractmethod
from typing import Any, cast
import torch
from tqdm import tqdm
from sglang.multimodal_gen.configs.pipelines import PipelineConfig
from sglang.multimodal_gen.runtime.loader.component_loader import (
PipelineComponentLoader,
)
from sglang.multimodal_gen.runtime.pipelines.executors.pipeline_executor import (
PipelineExecutor,
)
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import (
maybe_download_model,
verify_model_config_and_directory,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class ComposedPipelineBase(ABC):
"""
Base class for pipelines composed of multiple stages.
This class provides the framework for creating pipelines by composing multiple
stages together. Each stage is responsible for a specific part of the diffusion
process, and the pipeline orchestrates the execution of these stages.
"""
is_video_pipeline: bool = False # To be overridden by video pipelines
    # should contain only the modules to be loaded
_required_config_modules: list[str] = []
_extra_config_module_map: dict[str, str] = {}
server_args: ServerArgs | None = None
modules: dict[str, Any] = {}
post_init_called: bool = False
executor: PipelineExecutor | None = None
    # the name of the diffusers pipeline this class is associated with
pipeline_name: str
def __init__(
self,
model_path: str,
server_args: ServerArgs,
required_config_modules: list[str] | None = None,
loaded_modules: dict[str, torch.nn.Module] | None = None,
executor: PipelineExecutor | None = None,
):
"""
Initialize the pipeline. After __init__, the pipeline should be ready to
use. The pipeline should be stateless and not hold any batch state.
"""
self.server_args = server_args
self.model_path: str = model_path
self._stages: list[PipelineStage] = []
self._stage_name_mapping: dict[str, PipelineStage] = {}
self.executor = executor or self.build_executor(server_args=server_args)
if required_config_modules is not None:
self._required_config_modules = required_config_modules
if self._required_config_modules is None:
raise NotImplementedError("Subclass must set _required_config_modules")
        # temporarily disabled to avoid duplicate TP initialization
# maybe_init_distributed_environment_and_model_parallel(
# server_args.tp_size, server_args.sp_size
# )
# Load modules directly in initialization
logger.info("Loading pipeline modules...")
self.modules = self.load_modules(server_args, loaded_modules)
def build_executor(self, server_args: ServerArgs):
# TODO
from sglang.multimodal_gen.runtime.pipelines.executors.parallel_executor import (
ParallelExecutor,
)
# return SyncExecutor(server_args=server_args)
return ParallelExecutor(server_args=server_args)
def post_init(self) -> None:
assert self.server_args is not None, "server_args must be set"
if self.post_init_called:
return
self.post_init_called = True
self.initialize_pipeline(self.server_args)
if self.server_args.enable_torch_compile:
self.modules["transformer"] = torch.compile(self.modules["transformer"])
logger.info("Torch Compile enabled for DiT")
logger.info("Creating pipeline stages...")
self.create_pipeline_stages(self.server_args)
@classmethod
def from_pretrained(
cls,
model_path: str,
device: str | None = None,
torch_dtype: torch.dtype | None = None,
pipeline_config: str | PipelineConfig | None = None,
args: argparse.Namespace | None = None,
required_config_modules: list[str] | None = None,
loaded_modules: dict[str, torch.nn.Module] | None = None,
**kwargs,
) -> "ComposedPipelineBase":
"""
Load a pipeline from a pretrained model.
loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
If provided, loaded_modules will be used instead of loading from config/pretrained weights.
"""
kwargs["model_path"] = model_path
server_args = ServerArgs.from_kwargs(**kwargs)
logger.info("server_args in from_pretrained: %s", server_args)
pipe = cls(
model_path,
server_args,
required_config_modules=required_config_modules,
loaded_modules=loaded_modules,
)
pipe.post_init()
return pipe
def get_module(self, module_name: str, default_value: Any = None) -> Any:
if module_name not in self.modules:
return default_value
return self.modules[module_name]
def add_module(self, module_name: str, module: Any):
self.modules[module_name] = module
def _load_config(self) -> dict[str, Any]:
model_path = maybe_download_model(self.model_path)
self.model_path = model_path
# server_args.downloaded_model_path = model_path
logger.info("Model path: %s", model_path)
config = verify_model_config_and_directory(model_path)
return cast(dict[str, Any], config)
@property
def required_config_modules(self) -> list[str]:
"""
List of modules that are required by the pipeline. The names should match
the diffusers directory and model_index.json file. These modules will be
loaded using the PipelineComponentLoader and made available in the
modules dictionary. Access these modules using the get_module method.
class ConcretePipeline(ComposedPipelineBase):
_required_config_modules = ["vae", "text_encoder", "transformer", "scheduler", "tokenizer"]
@property
def required_config_modules(self):
return self._required_config_modules
"""
return self._required_config_modules
@property
def stages(self) -> list[PipelineStage]:
"""
List of stages in the pipeline.
"""
return self._stages
@abstractmethod
def create_pipeline_stages(self, server_args: ServerArgs):
"""
Create the inference pipeline stages.
"""
raise NotImplementedError
def initialize_pipeline(self, server_args: ServerArgs):
"""
Initialize the pipeline.
"""
return
def load_modules(
self,
server_args: ServerArgs,
loaded_modules: dict[str, torch.nn.Module] | None = None,
) -> dict[str, Any]:
"""
Load the modules from the config.
loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
If provided, loaded_modules will be used instead of loading from config/pretrained weights.
"""
model_index = self._load_config()
logger.info("Loading pipeline modules from config: %s", model_index)
# remove keys that are not pipeline modules
model_index.pop("_class_name")
model_index.pop("_diffusers_version")
if (
"boundary_ratio" in model_index
and model_index["boundary_ratio"] is not None
):
logger.info(
"MoE pipeline detected. Adding transformer_2 to self.required_config_modules..."
)
self.required_config_modules.append("transformer_2")
logger.info(
"MoE pipeline detected. Setting boundary ratio to %s",
model_index["boundary_ratio"],
)
server_args.pipeline_config.dit_config.boundary_ratio = model_index[
"boundary_ratio"
]
model_index.pop("boundary_ratio", None)
# used by Wan2.2 ti2v
model_index.pop("expand_timesteps", None)
# some sanity checks
assert (
len(model_index) > 1
), "model_index.json must contain at least one pipeline module"
model_index = {
required_module: model_index[required_module]
for required_module in self.required_config_modules
}
for module_name in self.required_config_modules:
if (
module_name not in model_index
and module_name in self._extra_config_module_map
):
extra_module_value = self._extra_config_module_map[module_name]
logger.warning(
"model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.",
module_name,
module_name,
extra_module_value,
)
if extra_module_value in model_index:
logger.info(
"Using module %s for %s", extra_module_value, module_name
)
model_index[module_name] = model_index[extra_module_value]
continue
else:
raise ValueError(
f"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}"
)
# all the component models used by the pipeline
required_modules = self.required_config_modules
logger.info("Loading required components: %s", required_modules)
components = {}
for module_name, (
transformers_or_diffusers,
architecture,
) in tqdm(iterable=model_index.items(), desc="Loading required modules"):
if transformers_or_diffusers is None:
logger.warning(
"Module %s in model_index.json has null value, removing from required_config_modules",
module_name,
)
if module_name in self.required_config_modules:
self.required_config_modules.remove(module_name)
continue
if module_name not in required_modules:
logger.info("Skipping module %s", module_name)
continue
if loaded_modules is not None and module_name in loaded_modules:
logger.info("Using module %s already provided", module_name)
components[module_name] = loaded_modules[module_name]
continue
# we load the module from the extra config module map if it exists
if module_name in self._extra_config_module_map:
load_module_name = self._extra_config_module_map[module_name]
else:
load_module_name = module_name
component_model_path = os.path.join(self.model_path, load_module_name)
module = PipelineComponentLoader.load_module(
module_name=load_module_name,
component_model_path=component_model_path,
transformers_or_diffusers=transformers_or_diffusers,
server_args=server_args,
)
logger.info("Loaded module %s from %s", module_name, component_model_path)
if module_name in components:
logger.warning("Overwriting module %s", module_name)
components[module_name] = module
# Check if all required modules were loaded
for module_name in required_modules:
if module_name not in components or components[module_name] is None:
raise ValueError(
f"Required module key: {module_name} value: {components.get(module_name)} was not found in loaded modules {components.keys()}"
)
return components
def add_stage(self, stage_name: str, stage: PipelineStage):
assert self.modules is not None, "No modules are registered"
self._stages.append(stage)
self._stage_name_mapping[stage_name] = stage
setattr(self, stage_name, stage)
# TODO(will): don't hardcode no_grad
@torch.no_grad()
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Generate a video or image using the pipeline.
Args:
batch: The batch to generate from.
server_args: The inference arguments.
Returns:
Req: The batch with the generated video or image.
"""
if not self.post_init_called:
self.post_init()
# Execute each stage
logger.info(
"Running pipeline stages: %s",
list(self._stage_name_mapping.keys()),
main_process_only=True,
)
return self.executor.execute(self.stages, batch, server_args)
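# Skeleton sketch (illustrative only, not a real pipeline): a concrete pipeline declares the
# modules it needs from model_index.json and registers its stages in create_pipeline_stages.
# The module list and stage names below are hypothetical; real pipelines register concrete
# PipelineStage instances via add_stage.
class _ExamplePipelineSkeleton(ComposedPipelineBase):
    _required_config_modules = ["vae", "transformer", "scheduler"]
    def create_pipeline_stages(self, server_args: ServerArgs) -> None:
        # e.g. self.add_stage("denoising_stage", SomeDenoisingStage(...))
        pass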
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
from typing import List
import torch
from sglang.multimodal_gen.runtime.distributed import get_sp_group
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_cfg_group,
get_classifier_free_guidance_rank,
)
from sglang.multimodal_gen.runtime.pipelines import Req
from sglang.multimodal_gen.runtime.pipelines.executors.pipeline_executor import (
PipelineExecutor,
Timer,
)
from sglang.multimodal_gen.runtime.pipelines.stages.base import (
PipelineStage,
StageParallelismType,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj
class ParallelExecutor(PipelineExecutor):
"""
The correctness of the execution relies on the parallelism_type declared by stages
"""
def collect_from_main(self, batches: list[Req]):
# TODO: fix this condition
if self.server_args.sp_degree != 1:
sp_group = get_sp_group()
batches = broadcast_pyobj(
batches,
sp_group.rank,
sp_group.cpu_group,
src=sp_group.ranks[0],
)
if self.server_args.enable_cfg_parallel:
batches = broadcast_pyobj(
batches,
self.worker.cfg_group.rank,
self.worker.cfg_cpu_group,
src=self.worker.cfg_group.ranks[0],
            )
        return batches
def execute(
self,
stages: List[PipelineStage],
batch: Req,
server_args: ServerArgs,
) -> Req:
        rank = get_classifier_free_guidance_rank()
        cfg_group = get_cfg_group()
# TODO: decide when to gather on main when CFG_PARALLEL -> MAIN_RANK_ONLY
for stage in stages:
with Timer(stage.__class__.__name__):
paradigm = stage.parallelism_type
if paradigm == StageParallelismType.MAIN_RANK_ONLY:
if rank == 0:
batch = stage(batch, server_args)
# obj_list = [batch] if rank == 0 else []
#
# broadcasted_list = broadcast_pyobj(
# obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0
# )
# if rank != 0:
# batch = broadcasted_list[0]
torch.distributed.barrier()
elif paradigm == StageParallelismType.CFG_PARALLEL:
obj_list = [batch] if rank == 0 else []
broadcasted_list = broadcast_pyobj(
obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0
)
if rank != 0:
batch = broadcasted_list[0]
batch = stage(batch, server_args)
torch.distributed.barrier()
elif paradigm == StageParallelismType.REPLICATED:
batch = stage(batch, server_args)
return batch
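

# Illustrative sketch (not part of the original file): the broadcast-then-run pattern
# used above for CFG_PARALLEL stages, written against vanilla torch.distributed
# instead of broadcast_pyobj. It assumes the default process group is already
# initialized; stage_fn stands in for any stage callable.
def _cfg_parallel_broadcast_example(stage_fn, batch):
    # Rank 0 owns the authoritative batch; every other rank receives a copy before
    # all ranks run the stage, so they start from identical state.
    obj_list = [batch if torch.distributed.get_rank() == 0 else None]
    torch.distributed.broadcast_object_list(obj_list, src=0)
    return stage_fn(obj_list[0])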
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Base class for all pipeline executors.
"""
import time
from abc import ABC, abstractmethod
from typing import List
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class Timer:
"""
A very simple timer that does not wait for the CUDA stream to be synchronized.
"""
def __init__(self, name="Stage"):
self.name = name
self.start = None
self.end = None
self.elapsed = None
def __enter__(self):
self.start = time.perf_counter()
logger.info(f"[{self.name}] started...")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.end = time.perf_counter()
self.elapsed = self.end - self.start
logger.info(f"[{self.name}] finished in {self.elapsed:.4f} seconds")
return False
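

# Illustrative usage sketch for Timer (not part of the original file). Note that the
# measurement is wall-clock host time only; queued CUDA work is not waited for.
def _timer_usage_example() -> float:
    with Timer("toy-stage") as t:
        sum(range(1_000_000))
    return t.elapsed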
class PipelineExecutor(ABC):
"""
Abstract base class for all pipeline executors.
Executors orchestrate the execution of pipeline stages, managing the parallelism and communication each stage requires.
"""
def __init__(self, server_args):
self.server_args = server_args
@abstractmethod
def execute(
self,
stages: List[PipelineStage],
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Execute the pipeline stages.
Args:
stages: A list of pipeline stages to execute.
batch: The batch to process.
server_args: The server arguments.
Returns:
The processed batch.
"""
raise NotImplementedError
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Synchronous pipeline executor implementation.
"""
from typing import List
from sglang.multimodal_gen.runtime.pipelines.executors.pipeline_executor import (
PipelineExecutor,
Timer,
logger,
)
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
class SyncExecutor(PipelineExecutor):
"""
A simple synchronous executor that runs stages sequentially.
"""
def execute(
self,
stages: List[PipelineStage],
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Execute the pipeline stages sequentially.
"""
logger.info("Running pipeline stages sequentially with SyncExecutor.")
for stage in stages:
with Timer(stage.__class__.__name__):
batch = stage(batch, server_args)
return batch
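

# Illustrative sketch (not part of the original file): with no parallel stages
# involved, the synchronous executor is just a sequential loop over
# already-constructed stages.
def _sync_executor_usage_example(
    stages: List[PipelineStage], batch: Req, server_args: ServerArgs
) -> Req:
    return SyncExecutor(server_args).execute(stages, batch, server_args)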
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from collections.abc import Hashable
from typing import Any
import torch
import torch.distributed as dist
from safetensors.torch import load_file
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.layers.lora.linear import (
BaseLayerWithLoRA,
get_lora_layer,
replace_submodule,
)
from sglang.multimodal_gen.runtime.loader.utils import get_param_names_mapping
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_lora
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class LoRAPipeline(ComposedPipelineBase):
"""
Pipeline that supports injecting LoRA adapters into the diffusion transformer.
TODO: support training.
"""
lora_adapters: dict[str, dict[str, torch.Tensor]] = defaultdict(
dict
) # state dicts of loaded lora adapters
cur_adapter_name: str = ""
cur_adapter_path: str = ""
lora_layers: dict[str, BaseLayerWithLoRA] = {}
lora_layers_critic: dict[str, BaseLayerWithLoRA] = {}
server_args: ServerArgs
exclude_lora_layers: list[str] = []
device: torch.device = get_local_torch_device()
lora_target_modules: list[str] | None = None
lora_path: str | None = None
lora_nickname: str = "default"
lora_rank: int | None = None
lora_alpha: int | None = None
lora_initialized: bool = False
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.device = get_local_torch_device()
self.exclude_lora_layers = self.modules[
"transformer"
].config.arch_config.exclude_lora_layers
self.lora_target_modules = self.server_args.lora_target_modules
self.lora_path = self.server_args.lora_path
self.lora_nickname = self.server_args.lora_nickname
if self.lora_path is not None:
self.convert_to_lora_layers()
self.set_lora_adapter(
self.lora_nickname, self.lora_path # type: ignore
) # type: ignore
def is_target_layer(self, module_name: str) -> bool:
if self.lora_target_modules is None:
return True
return any(
target_name in module_name for target_name in self.lora_target_modules
)
def convert_to_lora_layers(self) -> None:
"""
Unified method to convert the transformer to a LoRA transformer.
"""
if self.lora_initialized:
return
self.lora_initialized = True
converted_count = 0
for name, layer in self.modules["transformer"].named_modules():
if not self.is_target_layer(name):
continue
excluded = False
for exclude_layer in self.exclude_lora_layers:
if exclude_layer in name:
excluded = True
break
if excluded:
continue
layer = get_lora_layer(
layer,
lora_rank=self.lora_rank,
lora_alpha=self.lora_alpha,
)
if layer is not None:
self.lora_layers[name] = layer
replace_submodule(self.modules["transformer"], name, layer)
converted_count += 1
logger.info("Converted %d layers to LoRA layers", converted_count)
if "fake_score_transformer" in self.modules:
for name, layer in self.modules["fake_score_transformer"].named_modules():
if not self.is_target_layer(name):
continue
layer = get_lora_layer(
layer,
lora_rank=self.lora_rank,
lora_alpha=self.lora_alpha,
)
if layer is not None:
self.lora_layers_critic[name] = layer
replace_submodule(
self.modules["fake_score_transformer"], name, layer
)
converted_count += 1
logger.info(
"Converted %d layers to LoRA layers in the critic model",
converted_count,
)
def set_lora_adapter(
self, lora_nickname: str, lora_path: str | None = None
): # type: ignore
"""
Load a LoRA adapter into the pipeline and merge it into the transformer.
Args:
lora_nickname: The "nick name" of the adapter when referenced in the pipeline.
lora_path: The path to the adapter, either a local path or a Hugging Face repo id.
"""
if lora_nickname not in self.lora_adapters and lora_path is None:
raise ValueError(
f"Adapter {lora_nickname} not found in the pipeline. Please provide lora_path to load it."
)
if not self.lora_initialized:
self.convert_to_lora_layers()
adapter_updated = False
rank = dist.get_rank()
if lora_path is not None and lora_path != self.cur_adapter_path:
lora_local_path = maybe_download_lora(lora_path)
lora_state_dict = load_file(lora_local_path)
# Map the hf layer names to our custom layer names
param_names_mapping_fn = get_param_names_mapping(
self.modules["transformer"].param_names_mapping
)
lora_param_names_mapping_fn = get_param_names_mapping(
self.modules["transformer"].lora_param_names_mapping
)
to_merge_params: defaultdict[Hashable, dict[Any, Any]] = defaultdict(dict)
for name, weight in lora_state_dict.items():
name = name.replace("diffusion_model.", "")
name = name.replace(".weight", "")
name, _, _ = lora_param_names_mapping_fn(name)
target_name, merge_index, num_params_to_merge = param_names_mapping_fn(
name
)
# for (in_dim, r) @ (r, out_dim), we only merge (r, out_dim * n) where n is the number of linear layers to fuse
# see param mapping in HunyuanVideoArchConfig
if merge_index is not None and "lora_B" in name:
to_merge_params[target_name][merge_index] = weight
if len(to_merge_params[target_name]) == num_params_to_merge:
# cat at output dim according to the merge_index order
sorted_tensors = [
to_merge_params[target_name][i]
for i in range(num_params_to_merge)
]
weight = torch.cat(sorted_tensors, dim=1)
del to_merge_params[target_name]
else:
continue
if target_name in self.lora_adapters[lora_nickname]:
raise ValueError(
f"Target name {target_name} already exists in lora_adapters[{lora_nickname}]"
)
self.lora_adapters[lora_nickname][target_name] = weight.to(self.device)
adapter_updated = True
self.cur_adapter_path = lora_path
logger.info("Rank %d: loaded LoRA adapter %s", rank, lora_path)
if not adapter_updated and self.cur_adapter_name == lora_nickname:
return
self.cur_adapter_name = lora_nickname
# Merge the new adapter
adapted_count = 0
for name, layer in self.lora_layers.items():
lora_A_name = name + ".lora_A"
lora_B_name = name + ".lora_B"
if (
lora_A_name in self.lora_adapters[lora_nickname]
and lora_B_name in self.lora_adapters[lora_nickname]
):
layer.set_lora_weights(
self.lora_adapters[lora_nickname][lora_A_name],
self.lora_adapters[lora_nickname][lora_B_name],
lora_path=lora_path,
)
adapted_count += 1
else:
if rank == 0:
logger.warning(
"LoRA adapter %s does not contain the weights for layer %s. LoRA will not be applied to it.",
lora_path,
name,
)
layer.disable_lora = True
logger.info(
"Rank %d: LoRA adapter %s applied to %d layers",
rank,
lora_path,
adapted_count,
)
def merge_lora_weights(self) -> None:
for name, layer in self.lora_layers.items():
layer.merge_lora_weights()
def unmerge_lora_weights(self) -> None:
for name, layer in self.lora_layers.items():
layer.unmerge_lora_weights()
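

# Illustrative sketch (not part of the original file): the low-rank update that a
# merge step typically applies. Under the common LoRA convention a frozen weight W
# (out_dim x in_dim) is adjusted by (alpha / rank) * (B @ A), with A of shape
# (rank x in_dim) and B of shape (out_dim x rank). The exact shapes and scaling used
# by BaseLayerWithLoRA may differ; this only sketches the math.
def _lora_merge_math_example() -> torch.Tensor:
    out_dim, in_dim, rank, alpha = 8, 16, 4, 8
    W = torch.randn(out_dim, in_dim)
    A = torch.randn(rank, in_dim)
    B = torch.randn(out_dim, rank)
    return W + (alpha / rank) * (B @ A)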
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/forward_batch_info.py
"""
Data structures for functional pipeline processing.
This module defines the dataclasses used to pass state between pipeline components
in a functional manner, reducing the need for explicit parameter passing.
"""
import pprint
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any
import PIL.Image
import torch
from sglang.multimodal_gen.configs.sample.base import DataType
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.performance_logger import PerformanceLogger
if TYPE_CHECKING:
from torchcodec.decoders import VideoDecoder
import time
from collections import OrderedDict
from sglang.multimodal_gen.configs.sample.teacache import (
TeaCacheParams,
WanTeaCacheParams,
)
class PipelineLoggingInfo:
"""Simple approach using OrderedDict to track stage metrics."""
def __init__(self):
# OrderedDict preserves insertion order and allows easy access
self.stages: OrderedDict[str, dict[str, Any]] = OrderedDict()
def add_stage_execution_time(self, stage_name: str, execution_time: float):
"""Add execution time for a stage."""
if stage_name not in self.stages:
self.stages[stage_name] = {}
self.stages[stage_name]["execution_time"] = execution_time
self.stages[stage_name]["timestamp"] = time.time()
def add_stage_metric(self, stage_name: str, metric_name: str, value: Any):
"""Add any metric for a stage."""
if stage_name not in self.stages:
self.stages[stage_name] = {}
self.stages[stage_name][metric_name] = value
def get_stage_info(self, stage_name: str) -> dict[str, Any]:
"""Get all info for a specific stage."""
return self.stages.get(stage_name, {})
def get_execution_order(self) -> list[str]:
"""Get stages in execution order."""
return list(self.stages.keys())
def get_total_execution_time(self) -> float:
"""Get total pipeline execution time."""
return sum(stage.get("execution_time", 0) for stage in self.stages.values())
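

# Illustrative usage sketch for PipelineLoggingInfo (not part of the original file):
# record per-stage timings and arbitrary metrics, then read back the total.
def _pipeline_logging_info_example() -> float:
    info = PipelineLoggingInfo()
    info.add_stage_execution_time("TextEncodingStage", 0.12)
    info.add_stage_execution_time("DenoisingStage", 3.40)
    info.add_stage_metric("DenoisingStage", "num_inference_steps", 50)
    return info.get_total_execution_time()  # 3.52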
@dataclass
class Req:
"""
Complete state passed through the pipeline execution.
This dataclass contains all information needed during the diffusion pipeline
execution, allowing methods to update specific components without needing
to manage numerous individual parameters.
"""
# TODO(will): double check that args are separate from server_args
# properly. Also maybe think about providing an abstraction for pipeline
# specific arguments.
data_type: DataType
request_id: str | None = None
generator: torch.Generator | list[torch.Generator] | None = None
# Image inputs
image_path: str | None = None
# Image encoder hidden states
image_embeds: list[torch.Tensor] = field(default_factory=list)
pil_image: torch.Tensor | PIL.Image.Image | None = None
pixel_values: torch.Tensor | PIL.Image.Image | None = None
preprocessed_image: torch.Tensor | None = None
# Text inputs
prompt: str | list[str] | None = None
negative_prompt: str | list[str] | None = None
prompt_path: str | None = None
output_path: str = "outputs/"
# without extension
output_file_name: str | None = None
output_file_ext: str | None = None
# Primary encoder embeddings
prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list)
negative_prompt_embeds: list[torch.Tensor] | None = None
prompt_attention_mask: list[torch.Tensor] | None = None
negative_attention_mask: list[torch.Tensor] | None = None
clip_embedding_pos: list[torch.Tensor] | None = None
clip_embedding_neg: list[torch.Tensor] | None = None
pooled_embeds: list[torch.Tensor] = field(default_factory=list)
neg_pooled_embeds: list[torch.Tensor] = field(default_factory=list)
# Additional text-related parameters
max_sequence_length: int | None = None
prompt_template: dict[str, Any] | None = None
do_classifier_free_guidance: bool = False
# Batch info
num_outputs_per_prompt: int = 1
seed: int | None = None
seeds: list[int] | None = None
# Tracking if embeddings are already processed
is_prompt_processed: bool = False
# Latent tensors
latents: torch.Tensor | None = None
raw_latent_shape: torch.Tensor | None = None
noise_pred: torch.Tensor | None = None
image_latent: torch.Tensor | None = None
# Latent dimensions
height_latents: list[int] | int | None = None
width_latents: list[int] | int | None = None
num_frames: list[int] | int = 1 # Default for image models
num_frames_round_down: bool = (
False # Whether to round down num_frames if it's not divisible by num_gpus
)
# Original dimensions (before VAE scaling)
height: list[int] | int | None = None
width: list[int] | int | None = None
fps: list[int] | int | None = None
height_not_provided: bool = False
width_not_provided: bool = False
# Timesteps
timesteps: torch.Tensor | None = None
timestep: torch.Tensor | float | int | None = None
step_index: int | None = None
boundary_ratio: float | None = None
# Scheduler parameters
num_inference_steps: int = 50
guidance_scale: float = 1.0
guidance_scale_2: float | None = None
guidance_rescale: float = 0.0
eta: float = 0.0
sigmas: list[float] | None = None
n_tokens: int | None = None
# Other parameters that may be needed by specific schedulers
extra_step_kwargs: dict[str, Any] = field(default_factory=dict)
# Component modules (populated by the pipeline)
modules: dict[str, Any] = field(default_factory=dict)
return_trajectory_latents: bool = False
return_trajectory_decoded: bool = False
trajectory_timesteps: list[torch.Tensor] | None = None
trajectory_latents: torch.Tensor | None = None
# Extra parameters that might be needed by specific pipeline implementations
extra: dict[str, Any] = field(default_factory=dict)
# Misc
save_output: bool = True
return_frames: bool = False
# TeaCache parameters
enable_teacache: bool = False
teacache_params: TeaCacheParams | WanTeaCacheParams | None = None
# STA parameters
STA_param: list | None = None
is_cfg_negative: bool = False
mask_search_final_result_pos: list[list] | None = None
mask_search_final_result_neg: list[list] | None = None
# VSA parameters
VSA_sparsity: float = 0.0
perf_logger: PerformanceLogger | None = None
# profile
profile: bool = False
num_profiled_timesteps: int = 8
# debugging
debug: bool = False
# results
output: torch.Tensor | None = None
@property
def batch_size(self):
# Determine batch size
if isinstance(self.prompt, list):
batch_size = len(self.prompt)
elif self.prompt is not None:
batch_size = 1
else:
batch_size = self.prompt_embeds[0].shape[0]
# Adjust batch size for number of videos per prompt
batch_size *= self.num_outputs_per_prompt
return batch_size
def __post_init__(self):
"""Initialize dependent fields after dataclass initialization."""
# Set do_classifier_free_guidance based on guidance scale and negative prompt
if self.guidance_scale > 1.0 and self.negative_prompt is not None:
self.do_classifier_free_guidance = True
if self.negative_prompt_embeds is None:
self.negative_prompt_embeds = []
if self.guidance_scale_2 is None:
self.guidance_scale_2 = self.guidance_scale
if self.perf_logger is None:
self.perf_logger = PerformanceLogger(self.request_id)
def set_width_and_height(self, server_args: ServerArgs):
if self.height is None or self.width is None:
width, height = server_args.pipeline_config.set_width_and_height(
self.width, self.height, self.pil_image
)
self.width = width
self.height = height
if self.height is None or self.width is None:
self.width = 1280
self.height = 720
def __str__(self):
return pprint.pformat(asdict(self), indent=2, width=120)
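

# Illustrative sketch (not part of the original file): how batch_size is derived.
# The DataType member is left to the caller because its values are config-specific;
# "example-req" is an arbitrary request id used only for this sketch.
def _req_batch_size_example(data_type: DataType) -> int:
    req = Req(
        data_type=data_type,
        request_id="example-req",
        prompt=["a cat", "a dog"],
        num_outputs_per_prompt=2,
    )
    # len(prompt) == 2, multiplied by num_outputs_per_prompt == 2
    return req.batch_size  # 4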
@dataclass
class ForwardBatch: ...
@dataclass
class OutputBatch:
"""
Final output (after pipeline completion)
"""
output: torch.Tensor | None = None
trajectory_timesteps: list[torch.Tensor] | None = None
trajectory_latents: torch.Tensor | None = None
trajectory_decoded: list[torch.Tensor] | None = None
error: str | None = None
# Logging info
logging_info: PipelineLoggingInfo = field(default_factory=PipelineLoggingInfo)
@dataclass
class PreprocessBatch(Req):
video_loader: list["VideoDecoder"] | list[str] = field(default_factory=list)
video_file_name: list[str] = field(default_factory=list)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
# and https://github.com/sgl-project/sglang/blob/v0.4.3/python/sglang/srt/models/registry.py
import dataclasses
import importlib
import pkgutil
from collections.abc import Set
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline
from sglang.multimodal_gen.runtime.server_args import WorkloadType
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
_PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME: dict[WorkloadType, str] = {
WorkloadType.I2V: "PreprocessPipelineI2V",
WorkloadType.T2V: "PreprocessPipelineT2V",
}
class PipelineType(str, Enum):
"""
Enumeration for different pipeline types.
Inherits from str to allow string comparison for backward compatibility.
"""
BASIC = "basic"
PREPROCESS = "preprocess"
@classmethod
def from_string(cls, value: str) -> "PipelineType":
"""Convert string to PipelineType enum."""
try:
return cls(value.lower())
except ValueError:
raise ValueError(
f"Invalid pipeline type: {value}. Must be one of: {', '.join([t.value for t in cls])}"
) from None
@classmethod
def choices(cls) -> list[str]:
"""Get all available choices as strings."""
return [pipeline_type.value for pipeline_type in cls]
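

# Illustrative sketch (not part of the original file): strings are parsed
# case-insensitively and invalid values raise a ValueError listing the choices.
def _pipeline_type_parsing_example() -> list[str]:
    assert PipelineType.from_string("Basic") is PipelineType.BASIC
    return PipelineType.choices()  # ["basic", "preprocess"]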
@dataclass
class _PipelineRegistry:
# Keyed by pipeline_type -> pipeline_name
# pipelines[pipeline_type][pipeline_name] = pipeline_cls
pipelines: dict[str, dict[str, type[ComposedPipelineBase] | None]] = (
dataclasses.field(default_factory=dict)
)
def get_supported_archs(
self, pipeline_name_in_config: str, pipeline_type: PipelineType
) -> Set[str]:
"""Get supported architectures, optionally filtered by pipeline type and workload type."""
return set(self.pipelines[pipeline_type.value].keys())
def _load_preprocess_pipeline_cls(
self, workload_type: WorkloadType
) -> type[ComposedPipelineBase] | None:
pipeline_name = _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME[workload_type]
return self.pipelines[PipelineType.PREPROCESS.value][pipeline_name]
def _try_load_pipeline_cls(
self,
pipeline_name_in_config: str,
pipeline_type: PipelineType,
workload_type: WorkloadType,
) -> type[ComposedPipelineBase] | type[LoRAPipeline] | None:
"""Try to load a pipeline class for the given architecture, pipeline type, and workload type."""
if pipeline_type.value not in self.pipelines:
return None
try:
if pipeline_type == PipelineType.PREPROCESS:
return self._load_preprocess_pipeline_cls(workload_type)
elif pipeline_type == PipelineType.BASIC:
return self.pipelines[pipeline_type.value][pipeline_name_in_config]
else:
raise ValueError(f"Invalid pipeline type: {pipeline_type.value}")
except KeyError as e:
logger.error(
f"No ComposedPipeline class is registered for {pipeline_type.value}.{pipeline_name_in_config}. "
"Please check that it is defined and exported via EntryClass."
)
raise e
return None
def resolve_pipeline_cls(
self,
pipeline_name_in_config: str,
pipeline_type: PipelineType,
workload_type: WorkloadType,
) -> type[ComposedPipelineBase] | type[LoRAPipeline]:
"""Resolve pipeline class based on pipeline name in the config, pipeline type, and workload type."""
if not pipeline_name_in_config:
logger.warning("No pipeline architecture is specified")
pipeline_cls = self._try_load_pipeline_cls(
pipeline_name_in_config, pipeline_type, workload_type
)
if pipeline_cls is not None:
return pipeline_cls
supported_archs = self.get_supported_archs(
pipeline_name_in_config, pipeline_type
)
raise ValueError(
f"Pipeline architecture '{pipeline_name_in_config}' is not supported for pipeline type '{pipeline_type.value}' "
f"and workload type '{workload_type.value}'. "
f"Supported architectures: {supported_archs}"
)
@lru_cache
def import_pipeline_classes(
pipeline_types: list[PipelineType] | PipelineType | None = None,
) -> dict[str, dict[str, type[ComposedPipelineBase] | None]]:
"""
Import pipeline classes based on the pipeline type and workload type.
Args:
pipeline_types: The pipeline types to load (basic, preprocess).
If None, loads all types.
Returns:
A nested dictionary keyed by pipeline type and then pipeline name:
{pipeline_type: {pipeline_name: pipeline_cls}}
e.g., {"basic": {"WanPipeline": WanPipeline}}
"""
type_to_pipeline_dict: dict[str, dict[str, type[ComposedPipelineBase] | None]] = {}
package_name: str = "sglang.multimodal_gen.runtime.architectures"
# Determine which pipeline types to scan
if isinstance(pipeline_types, list):
pipeline_types_to_scan = [
pipeline_type.value for pipeline_type in pipeline_types
]
elif isinstance(pipeline_types, PipelineType):
pipeline_types_to_scan = [pipeline_types.value]
else:
pipeline_types_to_scan = [pt.value for pt in PipelineType]
logger.info("Loading pipelines for types: %s", pipeline_types_to_scan)
for pipeline_type_str in pipeline_types_to_scan:
# Try to load from pipeline-type-specific directory first
pipeline_type_package_name = f"{package_name}.{pipeline_type_str}"
pipeline_dict: dict[str, type[ComposedPipelineBase] | None] = {}
try:
pipeline_type_package = importlib.import_module(pipeline_type_package_name)
logger.debug("Successfully imported %s", pipeline_type_package_name)
for _, arch, ispkg in pkgutil.iter_modules(pipeline_type_package.__path__):
arch_package_name = f"{pipeline_type_package_name}.{arch}"
if ispkg:
arch_package = importlib.import_module(arch_package_name)
for _, module_name, ispkg in pkgutil.walk_packages(
arch_package.__path__, arch_package_name + "."
):
if not ispkg:
pipeline_module = importlib.import_module(module_name)
if hasattr(pipeline_module, "EntryClass"):
entry_cls_list = pipeline_module.EntryClass
if not isinstance(entry_cls_list, list):
entry_cls_list = [entry_cls_list]
pipeline_names = [cls.__name__ for cls in entry_cls_list]
for entry_cls, pipeline_name in zip(
entry_cls_list, pipeline_names
):
assert (
pipeline_name not in pipeline_dict
), f"Duplicated pipeline implementation for {pipeline_name} in {pipeline_type_str}.{arch_package_name}"
assert hasattr(
entry_cls, "pipeline_name"
), f"{entry_cls} must define a pipeline_name attribute"
pipeline_dict[pipeline_name] = entry_cls
type_to_pipeline_dict[pipeline_type_str] = pipeline_dict
except ImportError as e:
raise ImportError(
f"Could not import {pipeline_type_package_name} when importing pipeline classes: {e}"
) from None
# Log summary
total_pipelines = sum(
len(pipeline_dict) for pipeline_dict in type_to_pipeline_dict.values()
)
logger.info(
"Loaded %d pipeline classes across %d types",
total_pipelines,
len(pipeline_types_to_scan),
)
return type_to_pipeline_dict
def get_pipeline_registry(
pipeline_type: PipelineType | str | None = None,
) -> _PipelineRegistry:
"""
Get a pipeline registry for the specified pipeline type.
Args:
pipeline_type: The pipeline type to load. If None, all pipeline types are loaded.
Returns:
A pipeline registry instance.
"""
if isinstance(pipeline_type, str):
pipeline_type = PipelineType.from_string(pipeline_type)
pipeline_classes = import_pipeline_classes(pipeline_type)
return _PipelineRegistry(pipeline_classes)
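

# Illustrative sketch (not part of the original file): resolving a pipeline class by
# the name recorded in the model config. "WanPipeline" and WorkloadType.T2V are
# assumed values for this sketch; in practice they come from the loaded config and
# server args, and resolution fails with a ValueError if the name is unknown.
def _registry_usage_example() -> type[ComposedPipelineBase]:
    registry = get_pipeline_registry(PipelineType.BASIC)
    return registry.resolve_pipeline_cls(
        "WanPipeline", PipelineType.BASIC, WorkloadType.T2V
    )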
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Pipeline stages for diffusion models.
This package contains the various stages that can be composed to create
complete diffusion pipelines.
"""
from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.pipelines.stages.causal_denoising import (
CausalDMDDenoisingStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.conditioning import (
ConditioningStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.decoding import DecodingStage
from sglang.multimodal_gen.runtime.pipelines.stages.denoising import DenoisingStage
from sglang.multimodal_gen.runtime.pipelines.stages.denoising_dmd import (
DmdDenoisingStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.encoding import EncodingStage
from sglang.multimodal_gen.runtime.pipelines.stages.image_encoding import (
ImageEncodingStage,
ImageVAEEncodingStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.input_validation import (
InputValidationStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.latent_preparation import (
LatentPreparationStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.stepvideo_encoding import (
StepvideoPromptEncodingStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.text_encoding import (
TextEncodingStage,
)
from sglang.multimodal_gen.runtime.pipelines.stages.timestep_preparation import (
TimestepPreparationStage,
)
__all__ = [
"PipelineStage",
"InputValidationStage",
"TimestepPreparationStage",
"LatentPreparationStage",
"ConditioningStage",
"DenoisingStage",
"DmdDenoisingStage",
"CausalDMDDenoisingStage",
"EncodingStage",
"DecodingStage",
"ImageEncodingStage",
"ImageVAEEncodingStage",
"TextEncodingStage",
"StepvideoPromptEncodingStage",
]
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
# SPDX-License-Identifier: Apache-2.0
"""
Base classes for pipeline stages.
This module defines the abstract base classes for pipeline stages that can be
composed to create complete diffusion pipelines.
"""
import time
import traceback
from abc import ABC, abstractmethod
from enum import Enum, auto
import torch
import sglang.multimodal_gen.envs as envs
from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult
from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class StageParallelismType(Enum):
# executed on all GPUs
REPLICATED = auto()
# executed on the main rank only
MAIN_RANK_ONLY = auto()
# this stage requires CFG-parallel execution
CFG_PARALLEL = auto()
class StageVerificationError(Exception):
"""Exception raised when stage verification fails."""
pass
class PipelineStage(ABC):
"""
Abstract base class for all pipeline stages.
A pipeline stage represents a discrete step in the diffusion process that can be
composed with other stages to create a complete pipeline. Each stage is responsible
for a specific part of the process, such as prompt encoding, latent preparation, etc.
"""
def __init__(self):
self.server_args = get_global_server_args()
def log_info(self, msg, *args):
"""Logs an informational message with the stage name as a prefix."""
logger.info(f"[{self.__class__.__name__}] {msg}", *args)
def log_warning(self, msg, *args):
"""Logs a warning message with the stage name as a prefix."""
logger.warning(f"[{self.__class__.__name__}] {msg}", *args)
def log_error(self, msg, *args):
"""Logs an error message with the stage name as a prefix."""
logger.error(f"[{self.__class__.__name__}] {msg}", *args)
def log_debug(self, msg, *args):
"""Logs a debug message with the stage name as a prefix."""
logger.debug(f"[{self.__class__.__name__}] {msg}", *args)
def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""
Verify the input for the stage.
Example:
from sglang.multimodal_gen.runtime.pipelines.stages.validators import V, VerificationResult
def verify_input(self, batch, server_args):
result = VerificationResult()
result.add_check("height", batch.height, V.positive_int_divisible(8))
result.add_check("width", batch.width, V.positive_int_divisible(8))
result.add_check("image_latent", batch.image_latent, V.is_tensor)
return result
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
A VerificationResult containing the verification status.
"""
# Default implementation - no verification
return VerificationResult()
# execute on all ranks by default
@property
def parallelism_type(self) -> StageParallelismType:
# if get_global_server_args().enable_cfg_parallel:
# return StageParallelismType.MAIN_RANK_ONLY
return StageParallelismType.REPLICATED
def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult:
"""
Verify the output for the stage.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
A VerificationResult containing the verification status.
"""
# Default implementation - no verification
return VerificationResult()
def _run_verification(
self,
verification_result: VerificationResult,
stage_name: str,
verification_type: str,
) -> None:
"""
Run verification and raise errors if any checks fail.
Args:
verification_result: Results from verify_input or verify_output
stage_name: Name of the current stage
verification_type: "input" or "output"
"""
if not verification_result.is_valid():
failed_fields = verification_result.get_failed_fields()
if failed_fields:
# Get detailed failure information
detailed_summary = verification_result.get_failure_summary()
failed_fields_str = ", ".join(failed_fields)
error_msg = (
f"{verification_type.capitalize()} verification failed for {stage_name}: "
f"Failed fields: {failed_fields_str}\n"
f"Details: {detailed_summary}"
)
raise StageVerificationError(error_msg)
@property
def device(self) -> torch.device:
"""Get the device for this stage."""
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def set_logging(self, enable: bool):
"""
Enable or disable logging for this stage.
Args:
enable: Whether to enable logging.
"""
self._enable_logging = enable
def __call__(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Execute the stage's processing on the batch with optional verification and logging.
Should not be overridden by subclasses.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The updated batch information after this stage's processing.
"""
stage_name = self.__class__.__name__
# Check if verification is enabled (simple approach for prototype)
enable_verification = getattr(server_args, "enable_stage_verification", False)
if enable_verification:
# Pre-execution input verification
try:
input_result = self.verify_input(batch, server_args)
self._run_verification(input_result, stage_name, "input")
except Exception as e:
logger.error("Input verification failed for %s: %s", stage_name, str(e))
raise
# Execute the actual stage logic
if envs.SGL_DIFFUSION_STAGE_LOGGING:
logger.info("[%s] Starting execution", stage_name)
start_time = time.perf_counter()
try:
result = self.forward(batch, server_args)
execution_time = time.perf_counter() - start_time
logger.info(
"[%s] Execution completed in %s ms",
stage_name,
execution_time * 1000,
)
batch.logging_info.add_stage_execution_time(stage_name, execution_time)
except Exception as e:
execution_time = time.perf_counter() - start_time
logger.error(
"[%s] Error during execution after %s ms: %s",
stage_name,
execution_time * 1000,
e,
)
logger.error("[%s] Traceback: %s", stage_name, traceback.format_exc())
raise
else:
# Direct execution (current behavior)
result = self.forward(batch, server_args)
if enable_verification:
# Post-execution output verification
try:
output_result = self.verify_output(result, server_args)
self._run_verification(output_result, stage_name, "output")
except Exception as e:
logger.error(
"Output verification failed for %s: %s", stage_name, str(e)
)
raise
return result
@abstractmethod
def forward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Forward pass of the stage's processing.
This method should be implemented by subclasses to provide the forward
processing logic for the stage.
Args:
batch: The current batch information.
server_args: The inference arguments.
Returns:
The updated batch information after this stage's processing.
"""
raise NotImplementedError
def backward(
self,
batch: Req,
server_args: ServerArgs,
) -> Req:
raise NotImplementedError
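

# Illustrative sketch (not part of the original file): the smallest possible concrete
# stage. Real stages implement their processing in forward() and may additionally
# override verify_input / verify_output (see the example in PipelineStage.verify_input)
# and parallelism_type.
class _NoOpStage(PipelineStage):
    @property
    def parallelism_type(self) -> StageParallelismType:
        # A stage that should only run once (e.g. one that writes outputs to disk)
        # could return StageParallelismType.MAIN_RANK_ONLY here instead.
        return StageParallelismType.REPLICATED

    def forward(self, batch: Req, server_args: ServerArgs) -> Req:
        self.log_info("passing the batch through unchanged")
        return batch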