Commit c1cacde6 authored by weishb

vllm-omni_0.15.0.rc1+fix1 first commit

parent 35607782
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
# Model-specific polynomial coefficients for rescaling L1 distances
# These coefficients account for model-specific characteristics in how embeddings change
# Source: TeaCache paper and ComfyUI-TeaCache empirical tuning
_MODEL_COEFFICIENTS = {
# FLUX transformer coefficients from TeaCache paper
"FluxTransformer2DModel": [
4.98651651e02,
-2.83781631e02,
5.58554382e01,
-3.82021401e00,
2.64230861e-01,
],
# Qwen-Image transformer coefficients from ComfyUI-TeaCache
# Tuned specifically for Qwen's dual-stream transformer architecture
    # Used by default for all Qwen-Image family pipelines
"QwenImageTransformer2DModel": [
-4.50000000e02,
2.80000000e02,
-4.50000000e01,
3.20000000e00,
-2.00000000e-02,
],
    # Bagel transformer coefficients
    # Bagel-specific values; these are not Qwen's coefficients despite the shared architecture
"Bagel": [1.33313129e06, -1.68644226e05, 7.95050740e03, -1.63747873e02, 1.26352397e00],
# Z-Image transformer coefficients
    # Copied from Qwen-Image; should be tuned specifically for Z-Image in the future
"ZImageTransformer2DModel": [
-4.50000000e02,
2.80000000e02,
-4.50000000e01,
3.20000000e00,
-2.00000000e-02,
],
}
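# Illustrative sketch (not executed at import time) of how these coefficients are
# consumed: TeaCacheHook builds `np.poly1d(config.coefficients)`, and numpy reads
# the five values highest-degree-first, i.e. a quartic rescaling polynomial.
#
#   import numpy as np
#   rescale = np.poly1d(_MODEL_COEFFICIENTS["QwenImageTransformer2DModel"])
#   y = rescale(0.05)  # c0*0.05**4 + c1*0.05**3 + c2*0.05**2 + c3*0.05 + c4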
@dataclass
class TeaCacheConfig:
"""
Configuration for TeaCache applied to transformer models.
TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up
diffusion model inference by reusing transformer block computations when consecutive
timestep embeddings are similar.
Args:
rel_l1_thresh: Threshold for accumulated relative L1 distance. When below threshold,
cached residual is reused. Values in [0.1, 0.3] work best:
- 0.2: ~1.5x speedup with minimal quality loss
- 0.4: ~1.8x speedup with slight quality loss
- 0.6: ~2.0x speedup with noticeable quality loss
coefficients: Polynomial coefficients for rescaling L1 distance. If None, uses
model-specific defaults based on transformer_type.
transformer_type: Transformer class name (e.g., "QwenImageTransformer2DModel").
Auto-detected from pipeline.transformer.__class__.__name__ in backend.
Defaults to "QwenImageTransformer2DModel".
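    Example:
        >>> # Minimal sketch: with coefficients=None, defaults resolve from transformer_type
        >>> cfg = TeaCacheConfig(rel_l1_thresh=0.2)
        >>> len(cfg.coefficients)
        5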
"""
rel_l1_thresh: float = 0.2
coefficients: list[float] | None = None
transformer_type: str = "QwenImageTransformer2DModel"
def __post_init__(self) -> None:
"""Validate and set default coefficients."""
if self.rel_l1_thresh <= 0:
raise ValueError(f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}")
if self.coefficients is None:
            # Use model-specific coefficients; fail fast if the transformer type is unsupported
if self.transformer_type not in _MODEL_COEFFICIENTS:
raise KeyError(
f"Cannot find coefficients for {self.transformer_type}. "
f"Supported: {list(_MODEL_COEFFICIENTS.keys())}"
)
self.coefficients = _MODEL_COEFFICIENTS[self.transformer_type]
if len(self.coefficients) != 5:
raise ValueError(f"coefficients must contain exactly 5 elements, got {len(self.coefficients)}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Model-specific extractors for TeaCache.
This module provides a registry of extractor functions that know how to extract
modulated inputs from different transformer architectures. Adding support for
a new model requires only adding a new extractor function to the registry.
Each extractor returns a CacheContext object containing all model-specific
information needed for generic caching, including preprocessing, transformer
execution, and postprocessing logic.
"""
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn as nn
from vllm_omni.diffusion.forward_context import get_forward_context
@dataclass
class CacheContext:
"""
Context object containing all model-specific information for caching.
This allows the TeaCacheHook to remain completely generic - all model-specific
logic is encapsulated in the extractor that returns this context.
Attributes:
modulated_input: Tensor used for cache decision (similarity comparison).
Must be a torch.Tensor extracted from the first transformer block,
typically after applying normalization and modulation.
hidden_states: Current hidden states (will be modified by caching).
Must be a torch.Tensor representing the main image/latent states
after preprocessing but before transformer blocks.
encoder_hidden_states: Optional encoder states (for dual-stream models).
Set to None for single-stream models (e.g., Flux).
For dual-stream models (e.g., Qwen), contains text encoder outputs.
temb: Timestep embedding tensor.
Must be a torch.Tensor containing the timestep conditioning.
run_transformer_blocks: Callable that executes model-specific transformer blocks.
Signature: () -> tuple[torch.Tensor, ...]
Returns:
tuple containing:
- [0]: processed hidden_states (required)
- [1]: processed encoder_hidden_states (optional, only for dual-stream)
Example for single-stream:
def run_blocks():
h = hidden_states
for block in module.transformer_blocks:
h = block(h, temb=temb)
return (h,)
Example for dual-stream:
def run_blocks():
h, e = hidden_states, encoder_hidden_states
for block in module.transformer_blocks:
e, h = block(h, e, temb=temb)
return (h, e)
        postprocess: Callable that performs model-specific output postprocessing.
            Signature: (torch.Tensor) -> torch.Tensor | Transformer2DModelOutput | tuple
Takes the processed hidden_states and applies final transformations
(normalization, projection) to produce the model output.
Example:
def postprocess(h):
h = module.norm_out(h, temb)
output = module.proj_out(h)
return Transformer2DModelOutput(sample=output)
extra_states: Optional dict for additional model-specific state.
Use this for models that need to pass additional context beyond
the standard fields.
"""
modulated_input: torch.Tensor
hidden_states: torch.Tensor
encoder_hidden_states: torch.Tensor | None
temb: torch.Tensor
run_transformer_blocks: Callable[[], tuple[torch.Tensor, ...]]
postprocess: Callable[[torch.Tensor], Any]
extra_states: dict[str, Any] | None = None
def validate(self) -> None:
"""
Validate that the CacheContext contains valid data.
Raises:
TypeError: If fields have wrong types
ValueError: If tensors have invalid properties
RuntimeError: If callables fail basic invocation tests
This method should be called after creating a CacheContext to catch
common developer errors early with clear error messages.
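
        Example (illustrative sketch; ``extractor`` stands for any registered extractor):
            >>> ctx = extractor(module, *args, **kwargs)
            >>> ctx.validate()  # raises TypeError/ValueError early on a malformed context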
"""
# Validate tensor fields
if not isinstance(self.modulated_input, torch.Tensor):
raise TypeError(f"modulated_input must be torch.Tensor, got {type(self.modulated_input)}")
if not isinstance(self.hidden_states, torch.Tensor):
raise TypeError(f"hidden_states must be torch.Tensor, got {type(self.hidden_states)}")
if self.encoder_hidden_states is not None and not isinstance(self.encoder_hidden_states, torch.Tensor):
raise TypeError(
f"encoder_hidden_states must be torch.Tensor or None, got {type(self.encoder_hidden_states)}"
)
if not isinstance(self.temb, torch.Tensor):
raise TypeError(f"temb must be torch.Tensor, got {type(self.temb)}")
# Validate callables
if not callable(self.run_transformer_blocks):
raise TypeError(f"run_transformer_blocks must be callable, got {type(self.run_transformer_blocks)}")
if not callable(self.postprocess):
raise TypeError(f"postprocess must be callable, got {type(self.postprocess)}")
# Validate tensor shapes are compatible
if self.modulated_input.shape[0] != self.hidden_states.shape[0]:
raise ValueError(
f"Batch size mismatch: modulated_input has batch size "
f"{self.modulated_input.shape[0]}, but hidden_states has "
f"{self.hidden_states.shape[0]}"
)
# Validate devices match
if self.modulated_input.device != self.hidden_states.device:
raise ValueError(
f"Device mismatch: modulated_input on {self.modulated_input.device}, "
f"hidden_states on {self.hidden_states.device}"
)
def extract_qwen_context(
module: nn.Module,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
timestep: torch.Tensor | float | int,
img_shapes: torch.Tensor,
txt_seq_lens: torch.Tensor,
guidance: torch.Tensor | None = None,
additional_t_cond: torch.Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> CacheContext:
"""
Extract cache context for QwenImageTransformer2DModel.
This is the ONLY Qwen-specific code needed for TeaCache support.
It encapsulates preprocessing, modulated input extraction, transformer execution,
and postprocessing logic.
Args:
module: QwenImageTransformer2DModel instance
hidden_states: Input hidden states tensor
encoder_hidden_states: Text encoder outputs
encoder_hidden_states_mask: Mask for text encoder
timestep: Current diffusion timestep
img_shapes: Image shapes for position embedding
txt_seq_lens: Text sequence lengths
guidance: Optional guidance scale for CFG
additional_t_cond: Optional additional timestep conditioning
attention_kwargs: Additional attention arguments
**kwargs: Additional keyword arguments ignored by this extractor
Returns:
CacheContext with all information needed for generic caching
"""
from diffusers.models.modeling_outputs import Transformer2DModelOutput
if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
raise ValueError("Module must have transformer_blocks")
# ============================================================================
# PREPROCESSING (Qwen-specific)
# ============================================================================
hidden_states = module.img_in(hidden_states)
    if not isinstance(timestep, torch.Tensor):
        # timestep may arrive as a Python float/int; normalize it to a tensor first
        timestep = torch.tensor([timestep], device=hidden_states.device)
    timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype)
encoder_hidden_states = module.txt_norm(encoder_hidden_states)
encoder_hidden_states = module.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = (
module.time_text_embed(timestep, hidden_states, additional_t_cond)
if guidance is None
else module.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
)
image_rotary_emb = module.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
# ============================================================================
# EXTRACT MODULATED INPUT (for cache decision)
# ============================================================================
block = module.transformer_blocks[0]
img_mod_params = block.img_mod(temb)
img_mod1, _ = img_mod_params.chunk(2, dim=-1)
img_modulated, _ = block.img_norm1(hidden_states, img_mod1)
# ============================================================================
# DEFINE TRANSFORMER EXECUTION (Qwen-specific)
# ============================================================================
def run_transformer_blocks():
"""Execute all Qwen transformer blocks."""
h = hidden_states
e = encoder_hidden_states
encoder_mask = encoder_hidden_states_mask
hidden_states_mask = None # default
if module.parallel_config is not None and module.parallel_config.sequence_parallel_size > 1:
ctx = get_forward_context()
if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
# Create mask for the full (padded) sequence
# valid positions = True, padding positions = False
batch_size = hidden_states.shape[0]
padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
hidden_states_mask = torch.ones(
batch_size,
padded_seq_len,
dtype=torch.bool,
device=hidden_states.device,
)
hidden_states_mask[:, ctx.sp_original_seq_len :] = False
# if mask is all true, set it to None
if hidden_states_mask is not None and hidden_states_mask.all():
hidden_states_mask = None
if encoder_mask is not None and encoder_mask.all():
encoder_mask = None
for block in module.transformer_blocks:
e, h = block(
hidden_states=h,
encoder_hidden_states=e,
encoder_hidden_states_mask=encoder_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
hidden_states_mask=hidden_states_mask,
)
return (h, e)
# ============================================================================
# DEFINE POSTPROCESSING (Qwen-specific)
# ============================================================================
return_dict = kwargs.get("return_dict", True)
def postprocess(h):
"""Apply Qwen-specific output postprocessing."""
h = module.norm_out(h, temb)
output = module.proj_out(h)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
# ============================================================================
# RETURN CONTEXT
# ============================================================================
return CacheContext(
modulated_input=img_modulated,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
run_transformer_blocks=run_transformer_blocks,
postprocess=postprocess,
)
def extract_bagel_context(
module: nn.Module,
x_t: torch.Tensor,
timestep: torch.Tensor | float | int,
packed_vae_token_indexes: torch.LongTensor,
packed_vae_position_ids: torch.LongTensor,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_indexes: torch.LongTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
key_values_lens: torch.IntTensor,
past_key_values: Any,
packed_key_value_indexes: torch.LongTensor,
**kwargs: Any,
) -> CacheContext:
"""
Extract cache context for Bagel model.
Args:
module: Bagel instance
x_t: Latent image input
timestep: Current timestep
packed_vae_token_indexes: Indexes for VAE tokens in packed sequence
packed_vae_position_ids: Position IDs for VAE tokens
packed_text_ids: Text token IDs
packed_text_indexes: Indexes for text tokens in packed sequence
packed_indexes: Global indexes
packed_position_ids: Global position IDs
packed_seqlens: Sequence lengths
key_values_lens: KV cache lengths
past_key_values: KV cache
packed_key_value_indexes: KV cache indexes
**kwargs: Additional keyword arguments
Returns:
CacheContext with all information needed for generic caching
"""
# 1. Embed text
packed_text_embedding = module.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), module.hidden_size))
packed_sequence[packed_text_indexes] = packed_text_embedding
# 2. Embed timestep
if not isinstance(timestep, torch.Tensor):
timestep = torch.tensor([timestep], device=x_t.device)
if timestep.dim() == 0:
timestep = timestep.unsqueeze(0)
# 3. Embed image (x_t)
packed_pos_embed = module.latent_pos_embed(packed_vae_position_ids)
packed_timestep_embeds = module.time_embedder(timestep)
x_t_emb = module.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
if x_t_emb.dtype != packed_sequence.dtype:
x_t_emb = x_t_emb.to(packed_sequence.dtype)
packed_sequence[packed_vae_token_indexes] = x_t_emb
# Use the full packed sequence as modulated input to match hidden_states size
modulated_input = packed_sequence
def run_transformer_blocks():
extra_inputs = {}
if module.use_moe:
extra_inputs = {
"mode": "gen",
"packed_vae_token_indexes": packed_vae_token_indexes,
"packed_text_indexes": packed_text_indexes,
}
output = module.language_model.forward(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
return (output.packed_query_sequence,)
def postprocess(h):
v_t = module.llm2vae(h)
v_t = v_t[packed_vae_token_indexes]
return v_t
return CacheContext(
modulated_input=modulated_input,
hidden_states=packed_sequence, # Use full packed sequence
encoder_hidden_states=None,
temb=packed_timestep_embeds, # Approximate
run_transformer_blocks=run_transformer_blocks,
postprocess=postprocess,
)
def extract_zimage_context(
module: nn.Module,
x: list[torch.Tensor],
t: torch.Tensor,
cap_feats: list[torch.Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
**kwargs: Any,
) -> CacheContext:
"""
Extract cache context for ZImageTransformer2DModel.
This is the ONLY Z-Image-specific code needed for TeaCache support.
It encapsulates preprocessing, modulated input extraction, transformer execution,
and postprocessing logic.
Args:
module: ZImageTransformer2DModel instance
x: List of image tensors per batch item
t: Timestep tensor
cap_feats: List of caption feature tensors per batch item
patch_size: Patch size for patchification (default: 2)
f_patch_size: Frame patch size (default: 1)
**kwargs: Additional keyword arguments ignored by this extractor
Returns:
CacheContext with all information needed for generic caching
"""
from torch.nn.utils.rnn import pad_sequence
if not hasattr(module, "layers") or len(module.layers) == 0:
raise ValueError("Module must have main transformer layers")
bsz = len(x)
device = x[0].device
# ============================================================================
# PREPROCESSING (Z-Image specific)
# ============================================================================
# Scale timestep and create timestep embedding
t_scaled = t * module.t_scale
adaln_input = module.t_embedder(t_scaled)
# Patchify and embed inputs
(
x_patches,
cap_feats_processed,
x_size,
x_pos_ids,
cap_pos_ids,
x_inner_pad_mask,
cap_inner_pad_mask,
) = module.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
# Process image patches through embedder and noise refiner
    x_item_seqlens = [len(p) for p in x_patches]
x_max_item_seqlen = max(x_item_seqlens)
x_embedded = torch.cat(x_patches, dim=0)
x_embedded = module.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_embedded)
# Match adaln_input dtype to x_embedded
adaln_input = adaln_input.type_as(x_embedded)
# Apply pad token
x_embedded[torch.cat(x_inner_pad_mask)] = module.x_pad_token
x_list = list(x_embedded.split(x_item_seqlens, dim=0))
# Compute rope embeddings for image patches
x_cos, x_sin = module.rope_embedder(torch.cat(x_pos_ids, dim=0))
x_cos = list(x_cos.split(x_item_seqlens, dim=0))
x_sin = list(x_sin.split(x_item_seqlens, dim=0))
# Pad sequences for batch processing
x_batched = pad_sequence(x_list, batch_first=True, padding_value=0.0)
x_cos_batched = pad_sequence(x_cos, batch_first=True, padding_value=0.0)
x_sin_batched = pad_sequence(x_sin, batch_first=True, padding_value=0.0)
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
# Run noise refiner blocks
for layer in module.noise_refiner:
x_batched = layer(x_batched, x_attn_mask, x_cos_batched, x_sin_batched, adaln_input)
# Process caption features through embedder and context refiner
    cap_item_seqlens = [len(c) for c in cap_feats_processed]
cap_max_item_seqlen = max(cap_item_seqlens)
cap_embedded = torch.cat(cap_feats_processed, dim=0)
cap_embedded = module.cap_embedder(cap_embedded)
cap_embedded[torch.cat(cap_inner_pad_mask)] = module.cap_pad_token
cap_list = list(cap_embedded.split(cap_item_seqlens, dim=0))
# Compute rope embeddings for caption
cap_cos, cap_sin = module.rope_embedder(torch.cat(cap_pos_ids, dim=0))
cap_cos = list(cap_cos.split(cap_item_seqlens, dim=0))
cap_sin = list(cap_sin.split(cap_item_seqlens, dim=0))
# Pad sequences for batch processing
cap_batched = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
cap_cos_batched = pad_sequence(cap_cos, batch_first=True, padding_value=0.0)
cap_sin_batched = pad_sequence(cap_sin, batch_first=True, padding_value=0.0)
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
# Run context refiner blocks
for layer in module.context_refiner:
cap_batched = layer(cap_batched, cap_attn_mask, cap_cos_batched, cap_sin_batched)
# Create unified sequence (image + caption)
unified_list = []
unified_cos_list = []
unified_sin_list = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified_list.append(torch.cat([x_batched[i][:x_len], cap_batched[i][:cap_len]]))
unified_cos_list.append(torch.cat([x_cos_batched[i][:x_len], cap_cos_batched[i][:cap_len]]))
unified_sin_list.append(torch.cat([x_sin_batched[i][:x_len], cap_sin_batched[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
unified_max_item_seqlen = max(unified_item_seqlens)
unified = pad_sequence(unified_list, batch_first=True, padding_value=0.0)
unified_cos = pad_sequence(unified_cos_list, batch_first=True, padding_value=0.0)
unified_sin = pad_sequence(unified_sin_list, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
# ============================================================================
# EXTRACT MODULATED INPUT (for cache decision)
# ============================================================================
# Use the first main transformer block's modulation
# The main layers have modulation=True and process the unified sequence
block = module.layers[0]
# Get modulation parameters: scale_msa, gate_msa, scale_mlp, gate_mlp
mod_params = block.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
scale_msa = 1.0 + mod_params[0]
# Extract modulated input: normalized hidden states scaled by modulation
modulated_input = block.attention_norm1(unified) * scale_msa
# ============================================================================
# DEFINE TRANSFORMER EXECUTION (Z-Image specific)
# ============================================================================
def run_transformer_blocks():
"""Execute all Z-Image main transformer blocks."""
h = unified
for layer in module.layers:
h = layer(h, unified_attn_mask, unified_cos, unified_sin, adaln_input)
return (h,)
# ============================================================================
# DEFINE POSTPROCESSING (Z-Image specific)
# ============================================================================
def postprocess(h):
"""Apply Z-Image specific output postprocessing."""
h = module.all_final_layer[f"{patch_size}-{f_patch_size}"](h, adaln_input)
h = list(h.unbind(dim=0))
output = module.unpatchify(h, x_size, patch_size, f_patch_size)
return output, {}
# ============================================================================
# RETURN CONTEXT
# ============================================================================
return CacheContext(
modulated_input=modulated_input,
hidden_states=unified,
encoder_hidden_states=None, # Z-Image uses unified sequence, no separate encoder states
temb=adaln_input,
run_transformer_blocks=run_transformer_blocks,
postprocess=postprocess,
extra_states={
"unified_attn_mask": unified_attn_mask,
"unified_cos": unified_cos,
"unified_sin": unified_sin,
"x_size": x_size,
"x_item_seqlens": x_item_seqlens,
"patch_size": patch_size,
"f_patch_size": f_patch_size,
},
)
# Registry for model-specific extractors
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
#
# Note: keys are transformer class names (not pipeline names) because TeaCache hooks
# operate on the transformer module, and multiple pipelines can share the same transformer.
EXTRACTOR_REGISTRY: dict[str, Callable] = {
"QwenImageTransformer2DModel": extract_qwen_context,
"Bagel": extract_bagel_context,
"ZImageTransformer2DModel": extract_zimage_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
}
def register_extractor(transformer_cls_name: str, extractor_fn: Callable) -> None:
"""
Register a new extractor function for a model type.
This allows extending TeaCache support to new models without modifying
the core TeaCache code.
Args:
transformer_cls_name: Transformer model type identifier (class name or type string)
extractor_fn: Function with signature (module, *args, **kwargs) -> CacheContext
Example:
>>> def extract_flux_context(module, hidden_states, timestep, guidance=None, **kwargs):
... # Preprocessing
... temb = module.time_text_embed(timestep, guidance)
... # Extract modulated input
... modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb)
... # Define execution
... def run_blocks():
... h = hidden_states
... for block in module.transformer_blocks:
... h = block(h, temb=temb)
... return (h,)
... # Define postprocessing
... def postprocess(h):
... return module.proj_out(module.norm_out(h, temb))
... # Return context
... return CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess)
>>> register_extractor("FluxTransformer2DModel", extract_flux_context)
"""
EXTRACTOR_REGISTRY[transformer_cls_name] = extractor_fn
def get_extractor(transformer_cls_name: str) -> Callable:
"""
Get extractor function for given transformer class.
This function looks up the extractor based on the exact transformer_cls_name string,
which should match the transformer type in the pipeline (i.e., pipeline.transformer.__class__.__name__).
Args:
transformer_cls_name: Transformer class name (e.g., "QwenImageTransformer2DModel")
Must exactly match a key in EXTRACTOR_REGISTRY.
Returns:
Extractor function with signature (module, *args, **kwargs) -> CacheContext
Raises:
ValueError: If model type not found in registry
Example:
>>> # Get extractor for QwenImageTransformer2DModel
>>> extractor = get_extractor("QwenImageTransformer2DModel")
>>> ctx = extractor(transformer, hidden_states, encoder_hidden_states, timestep, ...)
"""
# Direct lookup - no substring matching
if transformer_cls_name in EXTRACTOR_REGISTRY:
return EXTRACTOR_REGISTRY[transformer_cls_name]
# No match found
available_types = list(EXTRACTOR_REGISTRY.keys())
raise ValueError(
f"Unknown model type: '{transformer_cls_name}'. "
f"Available types: {available_types}\n"
f"To add support for a new model, use register_extractor() or add to EXTRACTOR_REGISTRY."
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Hook-based TeaCache implementation for vLLM-Omni.
This module implements a diffusers-style hook system that completely intercepts
the transformer forward pass, eliminating the need for any TeaCache-specific
code in model definitions. Model developers only need to add an extractor function
to support new models.
"""
from __future__ import annotations
from typing import Any
import numpy as np
import torch
from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig
from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
from vllm_omni.diffusion.cache.teacache.state import TeaCacheState
from vllm_omni.diffusion.distributed.parallel_state import (
get_classifier_free_guidance_rank,
get_classifier_free_guidance_world_size,
)
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook, StateManager
class TeaCacheHook(ModelHook):
"""
ModelHook implementing TeaCache for transformer models.
This hook completely intercepts the transformer's forward pass and implements
adaptive caching based on timestep embedding similarity. It's model-agnostic
and supports multiple model types through extractor functions.
Key features:
- Zero changes to model code
- CFG-aware with separate states for positive/negative branches
- CFG-parallel compatible: properly detects branch identity across ranks
- Model-specific polynomial rescaling
- Auto-detection of model types
Attributes:
config: TeaCache configuration with thresholds and callbacks
rescale_func: Polynomial function for rescaling L1 distances
state_manager: Manages TeaCacheState across forward passes
extractor_fn: Model-specific function to extract modulated input
"""
_HOOK_NAME = "teacache"
def __init__(self, config: TeaCacheConfig):
"""
Initialize TeaCacheHook.
Args:
config: TeaCache configuration object.
"""
super().__init__()
self.config = config
self.rescale_func = np.poly1d(config.coefficients)
self.state_manager = StateManager(TeaCacheState)
self.extractor_fn = None
self._forward_cnt = 0
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
"""
Initialize hook with extractor from config transformer model type.
Args:
module: The module to initialize the hook for.
Returns:
The initialized module.
"""
# Get extractor function based on transformer_type from config
# transformer_type is the transformer class name (e.g., "QwenImageTransformer2DModel")
self.extractor_fn = get_extractor(self.config.transformer_type)
# Set default context
self.state_manager.set_context("teacache")
return module
def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
"""
Generic forward handler that works for ANY model.
This method is completely model-agnostic. All model-specific logic
is encapsulated in the extractor function that returns a CacheContext.
The extractor does:
- Model-specific preprocessing
- Extraction of modulated input for cache decision
- Providing transformer execution callable
- Providing postprocessing callable
This hook does:
- CFG-aware state management
- Cache decision logic (generic)
- Residual caching and reuse
Args:
module: Transformer module (any architecture)
*args: Positional arguments for model forward
**kwargs: Keyword arguments for model forward
Returns:
Model output (format depends on model)
"""
# Get model-specific context from extractor
# The extractor encapsulates ALL model-specific logic
ctx = self.extractor_fn(module, *args, **kwargs)
# ============================================================================
# GENERIC CACHING LOGIC (works for all models)
# ============================================================================
# Set context based on CFG branch for separate state tracking
# With CFG-parallel, each rank processes only one branch:
# - cfg_rank 0: positive branch
# - cfg_rank > 0: negative branch
# Without CFG-parallel, branches alternate within a single rank
if getattr(module, "do_true_cfg", False):
cfg_parallel_size = get_classifier_free_guidance_world_size()
if cfg_parallel_size > 1:
cfg_rank = get_classifier_free_guidance_rank()
cache_branch = "negative" if cfg_rank > 0 else "positive"
else:
# No CFG-parallel: use forward counter to alternate branches
cache_branch = "negative" if self._forward_cnt % 2 == 1 else "positive"
else:
cache_branch = "positive"
context_name = f"teacache_{cache_branch}"
self.state_manager.set_context(context_name)
state = self.state_manager.get_state()
# Decide whether to compute or cache based on modulated input similarity
should_compute = self._should_compute_full_transformer(state, ctx.modulated_input)
if not should_compute and state.previous_residual is not None:
# ============================================================================
# FAST PATH: Reuse cached residuals
# ============================================================================
ctx.hidden_states = ctx.hidden_states + state.previous_residual
if state.previous_residual_encoder is not None and ctx.encoder_hidden_states is not None:
ctx.encoder_hidden_states = ctx.encoder_hidden_states + state.previous_residual_encoder
output = ctx.hidden_states
else:
# ============================================================================
# SLOW PATH: Full transformer computation
# ============================================================================
ori_hidden_states = ctx.hidden_states.clone()
ori_encoder_hidden_states = (
ctx.encoder_hidden_states.clone() if ctx.encoder_hidden_states is not None else None
)
# Run transformer blocks using model-specific callable
outputs = ctx.run_transformer_blocks()
# Update context with outputs
ctx.hidden_states = outputs[0]
if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
ctx.encoder_hidden_states = outputs[1]
# Cache residuals for next timestep
state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
if ori_encoder_hidden_states is not None:
state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
output = ctx.hidden_states
# Update state
state.previous_modulated_input = ctx.modulated_input.detach()
state.cnt += 1
self._forward_cnt += 1
# ============================================================================
# POSTPROCESSING (model-specific, via callable)
# ============================================================================
return ctx.postprocess(output)
def _should_compute_full_transformer(self, state: TeaCacheState, modulated_inp: torch.Tensor) -> bool:
"""
Determine whether to compute full transformer or reuse cached residual.
This implements the core TeaCache algorithm:
1. Always compute first timestep
2. For intermediate steps:
- Compute relative L1 distance between current and previous modulated inputs
- Apply polynomial rescaling with model-specific coefficients
- Accumulate rescaled distances
- Compare to threshold: below = cache, above = compute
Args:
state: Current TeaCacheState containing counters and cached values
modulated_inp: Modulated input extracted from first transformer block
Returns:
True to compute full transformer, False to reuse cached residual
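
        Example (illustrative numbers, assuming rel_l1_thresh=0.2):
            If rescaling maps each raw distance to ~0.08, the accumulator reaches
            0.08 and then 0.16 (both < 0.2), so those two steps reuse the cached
            residual; the next step pushes it to 0.24 >= 0.2, which resets the
            accumulator and triggers a full recompute.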
"""
# First timestep: always compute
if state.cnt == 0:
state.accumulated_rel_l1_distance = 0.0
return True
# Need previous input for comparison
if state.previous_modulated_input is None:
return True
# Compute relative L1 distance between consecutive modulated inputs
rel_distance = (
(
(modulated_inp - state.previous_modulated_input).abs().mean()
/ (state.previous_modulated_input.abs().mean() + 1e-8)
)
.cpu()
.item()
)
# Apply model-specific polynomial rescaling
rescaled_distance = float(self.rescale_func(rel_distance))
state.accumulated_rel_l1_distance += abs(rescaled_distance)
# Decision: below threshold = cache, above = compute
if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh:
return False # Use cache
else:
state.accumulated_rel_l1_distance = 0.0 # Reset accumulator
return True # Compute
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
"""
Reset all cached states for a new inference run.
Args:
module: The module to reset state for.
Returns:
The module with reset state.
"""
self.state_manager.reset()
self._forward_cnt = 0
return module
def apply_teacache_hook(module: torch.nn.Module, config: TeaCacheConfig) -> None:
"""
Apply TeaCache optimization to a transformer module.
This function registers a TeaCacheHook that completely intercepts the
module's forward pass, implementing adaptive caching without any changes
to the model code.
Args:
module: Transformer model to optimize (e.g., QwenImageTransformer2DModel)
config: TeaCacheConfig specifying caching parameters
Example:
>>> config = TeaCacheConfig(
... rel_l1_thresh=0.2,
... transformer_type="QwenImageTransformer2DModel"
... )
>>> apply_teacache_hook(transformer, config)
        >>> # The pipeline's transformer now uses TeaCache automatically;
        ... # no model code changes are needed.
"""
registry = HookRegistry.get_or_create(module)
hook = TeaCacheHook(config)
registry.register_hook(TeaCacheHook._HOOK_NAME, hook)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
TeaCache state management.
This module manages the state for TeaCache hooks across diffusion timesteps.
"""
import torch
class TeaCacheState:
"""
State management for TeaCache hook.
Tracks caching state across diffusion timesteps, managing counters,
accumulated distances, and cached residuals for the TeaCache algorithm.
"""
def __init__(self):
"""Initialize empty TeaCache state."""
# Timestep tracking
self.cnt = 0
# Caching state
self.accumulated_rel_l1_distance = 0.0
self.previous_modulated_input: torch.Tensor | None = None
self.previous_residual: torch.Tensor | None = None
self.previous_residual_encoder: torch.Tensor | None = None
def reset(self) -> None:
"""Reset all state variables for a new inference run."""
self.cnt = 0
self.accumulated_rel_l1_distance = 0.0
self.previous_modulated_input = None
self.previous_residual = None
self.previous_residual_encoder = None
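# Usage sketch: the hook's StateManager keeps one TeaCacheState per context
# (e.g., "teacache_positive" / "teacache_negative") and calls reset() between runs.
#
#   state = TeaCacheState()
#   state.cnt += 1
#   state.reset()
#   assert state.cnt == 0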
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.nn as nn
from vllm.logger import init_logger
logger = init_logger(__name__)
def regionally_compile(model: nn.Module, *compile_args: Any, **compile_kwargs: Any) -> nn.Module:
"""
Apply regional compilation to a PyTorch model.
Args:
model: The PyTorch model instance to compile
*compile_args: Positional arguments forwarded to torch.compile
**compile_kwargs: Keyword arguments forwarded to torch.compile
Returns:
The same model instance (modified in-place)
"""
# Get the list of repeated blocks from the model
repeated_blocks = getattr(model, "_repeated_blocks", None)
if not repeated_blocks:
logger.warning("Regional compilation skipped because the model does not define `_repeated_blocks`.")
return model
# Check if we have modules with the specified class names
has_compiled_region = False
for submod in model.modules():
if submod.__class__.__name__ in repeated_blocks:
# Compile this submodule
submod.compile(*compile_args, **compile_kwargs)
has_compiled_region = True
if not has_compiled_region:
logger.warning(f"Regional compilation skipped because {repeated_blocks} classes are not found in the model.")
return model
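# Usage sketch (hedged): `_repeated_blocks` must list the class names of the model's
# repeated transformer blocks; the block class name below is hypothetical.
#
#   model._repeated_blocks = ["MyTransformerBlock"]
#   regionally_compile(model, mode="max-autotune")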
# adapted from sglang and fastvideo
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import os
import random
from collections.abc import Callable
from dataclasses import dataclass, field, fields
from typing import Any
import torch
from pydantic import model_validator
from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm_omni.diffusion.utils.network_utils import is_port_available
logger = init_logger(__name__)
@config
@dataclass
class DiffusionParallelConfig:
"""Configuration for diffusion model distributed execution."""
pipeline_parallel_size: int = 1
"""Number of pipeline parallel stages."""
data_parallel_size: int = 1
"""Number of data parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
sequence_parallel_size: int | None = None
"""Number of sequence parallel groups. sequence_parallel_size = ring_degree * ulysses_degree"""
ulysses_degree: int = 1
"""Number of GPUs used for ulysses sequence parallelism."""
ring_degree: int = 1
"""Number of GPUs used for ring sequence parallelism."""
cfg_parallel_size: int = 1
"""Number of Classifier Free Guidance (CFG) parallel groups."""
@model_validator(mode="after")
def _validate_parallel_config(self) -> Self:
"""Validates the config relationships among the parallel strategies."""
assert self.pipeline_parallel_size > 0, "Pipeline parallel size must be > 0"
assert self.data_parallel_size > 0, "Data parallel size must be > 0"
assert self.tensor_parallel_size > 0, "Tensor parallel size must be > 0"
assert self.sequence_parallel_size > 0, "Sequence parallel size must be > 0"
assert self.ulysses_degree > 0, "Ulysses degree must be > 0"
assert self.ring_degree > 0, "Ring degree must be > 0"
assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0"
assert self.cfg_parallel_size in [1, 2], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}"
assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, (
"Sequence parallel size must be equal to the product of ulysses degree and ring degree,"
f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}"
)
return self
def __post_init__(self) -> None:
if self.sequence_parallel_size is None:
self.sequence_parallel_size = self.ulysses_degree * self.ring_degree
self.world_size = (
self.pipeline_parallel_size
* self.data_parallel_size
* self.tensor_parallel_size
* self.ulysses_degree
* self.ring_degree
* self.cfg_parallel_size
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "DiffusionParallelConfig":
"""
Create DiffusionParallelConfig from a dictionary.
Args:
data: Dictionary containing parallel configuration parameters
Returns:
DiffusionParallelConfig instance with parameters set from dict
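
        Example:
            >>> cfg = DiffusionParallelConfig.from_dict({"ulysses_degree": 2})
            >>> cfg.sequence_parallel_size  # derived as ulysses_degree * ring_degree
            2
            >>> cfg.world_size
            2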
"""
if not isinstance(data, dict):
raise TypeError(f"Expected parallel config dict, got {type(data)!r}")
return cls(**data)
@dataclass
class TransformerConfig:
"""Container for raw transformer configuration dictionaries."""
params: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformerConfig":
if not isinstance(data, dict):
raise TypeError(f"Expected transformer config dict, got {type(data)!r}")
return cls(params=dict(data))
def to_dict(self) -> dict[str, Any]:
return dict(self.params)
def get(self, key: str, default: Any | None = None) -> Any:
return self.params.get(key, default)
def __getattr__(self, item: str) -> Any:
params = object.__getattribute__(self, "params")
try:
return params[item]
except KeyError as exc:
raise AttributeError(item) from exc
@dataclass
class DiffusionCacheConfig:
"""
Configuration for cache adapters (TeaCache, cache-dit, etc.).
This dataclass provides a unified interface for cache configuration parameters.
It can be initialized from a dictionary and accessed via attributes.
Common parameters:
- TeaCache: rel_l1_thresh, coefficients (optional)
- cache-dit: Fn_compute_blocks, Bn_compute_blocks, max_warmup_steps,
residual_diff_threshold, enable_taylorseer, taylorseer_order,
scm_steps_mask_policy, scm_steps_policy
Example:
>>> # From dict (user-facing API) - partial config uses defaults for missing keys
>>> config = DiffusionCacheConfig.from_dict({"rel_l1_thresh": 0.3})
>>> # Access via attribute
>>> print(config.rel_l1_thresh) # 0.3 (from dict)
>>> print(config.Fn_compute_blocks) # 8 (default)
>>> # Empty dict uses all defaults
>>> default_config = DiffusionCacheConfig.from_dict({})
>>> print(default_config.rel_l1_thresh) # 0.2 (default)
"""
# TeaCache parameters [tea_cache only]
# Default: 0.2 provides ~1.5x speedup with minimal quality loss (optimal balance)
rel_l1_thresh: float = 0.2
coefficients: list[float] | None = None # Uses model-specific defaults if None
    # cache-dit parameters [cache-dit only]
    # Default: 1 forward compute block instead of cache-dit's 8, optimized for
    # single-transformer models; this gives better performance while maintaining
    # quality for most use cases
Fn_compute_blocks: int = 1
# Default: 0 backward compute blocks (no fusion by default)
Bn_compute_blocks: int = 0
    # Default: 4 warmup steps instead of cache-dit's 8, so that DBCache also works
    # for few-step distilled models (e.g., Z-Image with 8 steps)
max_warmup_steps: int = 4
# Default: -1 (unlimited cached steps) - DBCache disables caching when previous cached steps exceed this value
# to prevent precision degradation. Set to -1 for unlimited caching (cache-dit default).
max_cached_steps: int = -1
    # Default: 0.24 residual difference threshold, relatively high to allow more
    # aggressive caching. This is safe because of the max_continuous_cached_steps
    # limit; without that limit, a lower threshold like 0.12 would be needed.
residual_diff_threshold: float = 0.24
# Default: Limit consecutive cached steps to 3 to prevent precision degradation
# This allows us to use a higher residual_diff_threshold for more aggressive caching
max_continuous_cached_steps: int = 3
    # Default: TaylorSeer disabled, since it is not suitable for few-step distilled models
# References:
# - From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
# - Forecast then Calibrate: Feature Caching as ODE for Efficient Diffusion Transformers
enable_taylorseer: bool = False
# Default: 1st order TaylorSeer polynomial
taylorseer_order: int = 1
# Default: None SCM mask policy (disabled by default)
scm_steps_mask_policy: str | None = None
# Default: "dynamic" steps policy for adaptive caching
scm_steps_policy: str = "dynamic"
# Used by cache-dit for scm mask generation. If this value changes during inference,
# we will re-generate the scm mask and refresh the cache context.
num_inference_steps: int | None = None
# Additional parameters that may be passed but not explicitly defined
_extra_params: dict[str, Any] = field(default_factory=dict, repr=False)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "DiffusionCacheConfig":
"""
Create DiffusionCacheConfig from a dictionary.
Args:
data: Dictionary containing cache configuration parameters
Returns:
DiffusionCacheConfig instance with parameters set from dict
"""
if not isinstance(data, dict):
raise TypeError(f"Expected cache config dict, got {type(data)!r}")
# Get all dataclass field names automatically
field_names = {f.name for f in fields(cls)}
# Extract parameters that match dataclass fields (excluding private fields)
known_params = {k: v for k, v in data.items() if k in field_names and not k.startswith("_")}
# Store extra parameters
extra_params = {k: v for k, v in data.items() if k not in field_names}
# Create instance with known params (missing ones will use defaults)
# Then update _extra_params after creation since it's a private field
instance = cls(**known_params, _extra_params=extra_params)
return instance
def __getattr__(self, item: str) -> Any:
"""
Allow access to extra parameters via attribute access.
This enables accessing parameters that weren't explicitly defined
in the dataclass fields but were passed in the dict.
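
        Example (``custom_knob`` is a hypothetical extra key):
            >>> cfg = DiffusionCacheConfig.from_dict({"custom_knob": 42})
            >>> cfg.custom_knob
            42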
"""
if item == "_extra_params" or item.startswith("_"):
return object.__getattribute__(self, item)
extra = object.__getattribute__(self, "_extra_params")
if item in extra:
return extra[item]
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
@dataclass
class OmniDiffusionConfig:
# Model and path configuration (for convenience)
model: str | None = None
model_class_name: str | None = None
dtype: torch.dtype = torch.bfloat16
tf_model_config: TransformerConfig = field(default_factory=TransformerConfig)
# Attention
attention_backend: str | None = None
# Running mode
# mode: ExecutionMode = ExecutionMode.INFERENCE
# Workload type
# workload_type: WorkloadType = WorkloadType.T2V
# Cache strategy (legacy)
cache_strategy: str = "none"
parallel_config: DiffusionParallelConfig = field(default_factory=DiffusionParallelConfig)
# Cache backend configuration (NEW)
cache_backend: str = "none" # "tea_cache", "deep_cache", etc.
cache_config: DiffusionCacheConfig | dict[str, Any] = field(default_factory=dict)
enable_cache_dit_summary: bool = False
# Distributed executor backend
distributed_executor_backend: str = "mp"
nccl_port: int | None = None
# HuggingFace specific parameters
trust_remote_code: bool = False
revision: str | None = None
num_gpus: int | None = None
hsdp_replicate_dim: int = 1
hsdp_shard_dim: int = -1
dist_timeout: int | None = None # timeout for torch.distributed
# pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)
# LoRA parameters
lora_path: str | None = None
lora_scale: float = 1.0
max_cpu_loras: int | None = None
output_type: str = "pil"
# CPU offload parameters
# When enabled, DiT and encoders swap GPU access (mutual exclusion):
# - Text encoders run on GPU while DiT is on CPU
# - DiT runs on GPU while encoders are on CPU
enable_cpu_offload: bool = False
# Layer-wise offloading (block-level offloading) parameters
enable_layerwise_offload: bool = False
    # Number of transformer blocks to keep resident on GPU, ready for computation
layerwise_num_gpu_layers: int = 1
use_fsdp_inference: bool = False
pin_cpu_memory: bool = True # Use pinned memory for faster transfers when offloading
# VAE memory optimization parameters
vae_use_slicing: bool = False
vae_use_tiling: bool = False
# STA (Sliding Tile Attention) parameters
mask_strategy_file_path: str | None = None
# STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
skip_time_steps: int = 15
# Compilation
enforce_eager: bool = False
# Enable sleep mode
enable_sleep_mode: bool = False
disable_autocast: bool = False
# VSA parameters
VSA_sparsity: float = 0.0 # inference/validation sparsity
# V-MoBA parameters
moba_config_path: str | None = None
# moba_config: dict[str, Any] = field(default_factory=dict)
# Master port for distributed inference
# TODO: do not hard code
master_port: int | None = None
    # HTTP server endpoint config; ignored in local mode
host: str | None = None
port: int | None = None
scheduler_port: int = 5555
# Stage verification
enable_stage_verification: bool = True
# Prompt text file for batch processing
prompt_file_path: str | None = None
# model paths for correct deallocation
model_paths: dict[str, str] = field(default_factory=dict)
model_loaded: dict[str, bool] = field(
default_factory=lambda: {
"transformer": True,
"vae": True,
}
)
override_transformer_cls_name: str | None = None
# # DMD parameters
# dmd_denoising_steps: List[int] | None = field(default=None)
# MoE parameters used by Wan2.2
boundary_ratio: float | None = None
# Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p)
flow_shift: float | None = None
    # Support multi-image inputs
supports_multimodal_inputs: bool = False
# Logging
log_level: str = "info"
# Omni configuration (injected from stage config)
omni_kv_config: dict[str, Any] = field(default_factory=dict)
def settle_port(self, port: int, port_inc: int = 42, max_attempts: int = 100) -> int:
"""
Find an available port with retry logic.
Args:
port: Initial port to check
port_inc: Port increment for each attempt
max_attempts: Maximum number of attempts to find an available port
Returns:
An available port number
Raises:
RuntimeError: If no available port is found after max_attempts
"""
attempts = 0
original_port = port
while attempts < max_attempts:
if is_port_available(port):
if attempts > 0:
logger.info(f"Port {original_port} was unavailable, using port {port} instead")
return port
attempts += 1
if port < 60000:
port += port_inc
else:
# Wrap around with randomization to avoid collision
port = 5000 + random.randint(0, 1000)
raise RuntimeError(
f"Failed to find available port after {max_attempts} attempts (started from port {original_port})"
)
def __post_init__(self):
# TODO: remove hard code
initial_master_port = (self.master_port or 30005) + random.randint(0, 100)
self.master_port = self.settle_port(initial_master_port, 37)
# Convert parallel_config dict to DiffusionParallelConfig if needed
# This must be done before accessing parallel_config.world_size
if isinstance(self.parallel_config, dict):
self.parallel_config = DiffusionParallelConfig.from_dict(self.parallel_config)
elif not isinstance(self.parallel_config, DiffusionParallelConfig):
# If it's neither dict nor DiffusionParallelConfig, use default config
self.parallel_config = DiffusionParallelConfig()
if self.num_gpus is None:
if self.parallel_config is not None:
self.num_gpus = self.parallel_config.world_size
else:
self.num_gpus = 1
if self.num_gpus < self.parallel_config.world_size:
raise ValueError(
f"num_gpus ({self.num_gpus}) < parallel_config.world_size ({self.parallel_config.world_size})"
)
# Convert string dtype to torch.dtype if needed
if isinstance(self.dtype, str):
dtype_map = {
"auto": torch.bfloat16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
"float16": torch.float16,
"fp16": torch.float16,
"half": torch.float16,
"float32": torch.float32,
"fp32": torch.float32,
"float": torch.float32,
}
dtype_lower = self.dtype.lower()
if dtype_lower in dtype_map:
self.dtype = dtype_map[dtype_lower]
else:
logger.warning(f"Unknown dtype string '{self.dtype}', defaulting to bfloat16")
self.dtype = torch.bfloat16
# Convert cache_config dict to DiffusionCacheConfig if needed
if isinstance(self.cache_config, dict):
self.cache_config = DiffusionCacheConfig.from_dict(self.cache_config)
elif not isinstance(self.cache_config, DiffusionCacheConfig):
# If it's neither dict nor DiffusionCacheConfig, convert to empty config
self.cache_config = DiffusionCacheConfig()
if self.max_cpu_loras is None:
self.max_cpu_loras = 1
elif self.max_cpu_loras < 1:
raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA")
def update_multimodal_support(self) -> None:
self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"}
@classmethod
def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":
# Backwards-compatibility: older callers may use a diffusion-specific
# "static_lora_scale" kwarg. Normalize it to the canonical "lora_scale"
# before constructing the dataclass to avoid TypeError on unknown fields.
if "static_lora_scale" in kwargs:
if "lora_scale" not in kwargs:
kwargs["lora_scale"] = kwargs["static_lora_scale"]
kwargs.pop("static_lora_scale", None)
# Check environment variable as fallback for cache_backend
# Support both old DIFFUSION_CACHE_ADAPTER and new DIFFUSION_CACHE_BACKEND for backwards compatibility
if "cache_backend" not in kwargs:
cache_backend = os.environ.get("DIFFUSION_CACHE_BACKEND") or os.environ.get("DIFFUSION_CACHE_ADAPTER")
kwargs["cache_backend"] = cache_backend.lower() if cache_backend else "none"
# Filter kwargs to only include valid fields
valid_fields = {f.name for f in fields(cls)}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}
return cls(**filtered_kwargs)
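# Usage sketch (hedged): from_kwargs normalizes legacy kwargs and filters out
# unknown fields; "static_lora_scale" is the legacy spelling handled above.
#
#   cfg = OmniDiffusionConfig.from_kwargs(model="Qwen/Qwen-Image", static_lora_scale=0.8)
#   assert cfg.lora_scale == 0.8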
@dataclass
class DiffusionOutput:
"""
Final output (after pipeline completion)
"""
output: torch.Tensor | None = None
trajectory_timesteps: list[torch.Tensor] | None = None
trajectory_latents: torch.Tensor | None = None
trajectory_decoded: list[torch.Tensor] | None = None
error: str | None = None
post_process_func: Callable[..., Any] | None = None
# logged timings info, directly from Req.timings
# timings: Optional["RequestTimings"] = None
class AttentionBackendEnum(enum.Enum):
FA = enum.auto()
SLIDING_TILE_ATTN = enum.auto()
TORCH_SDPA = enum.auto()
SAGE_ATTN = enum.auto()
SAGE_ATTN_THREE = enum.auto()
VIDEO_SPARSE_ATTN = enum.auto()
VMOBA_ATTN = enum.auto()
AITER = enum.auto()
NO_ATTENTION = enum.auto()
def __str__(self):
return self.name.lower()
# Special message broadcast via scheduler queues to signal worker shutdown.
SHUTDOWN_MESSAGE = {"type": "shutdown"}