Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm-omni
Commits
c1cacde6
Commit
c1cacde6
authored
Mar 25, 2026
by
weishb
Browse files
vllm-omni_0.15.0.rc1+fix1 first commit
parent
35607782
Changes
306
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1604 additions
and
0 deletions
+1604
-0
vllm_omni/diffusion/cache/teacache/config.py
vllm_omni/diffusion/cache/teacache/config.py
+84
-0
vllm_omni/diffusion/cache/teacache/extractors.py
vllm_omni/diffusion/cache/teacache/extractors.py
+650
-0
vllm_omni/diffusion/cache/teacache/hook.py
vllm_omni/diffusion/cache/teacache/hook.py
+272
-0
vllm_omni/diffusion/cache/teacache/state.py
vllm_omni/diffusion/cache/teacache/state.py
+38
-0
vllm_omni/diffusion/compile.py
vllm_omni/diffusion/compile.py
+41
-0
vllm_omni/diffusion/data.py
vllm_omni/diffusion/data.py
+519
-0
No files found.
Too many changes to show.
To preserve performance only
306 of 306+
files are displayed.
Plain diff
Email patch
vllm_omni/diffusion/cache/teacache/config.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
# Model-specific polynomial coefficients for rescaling L1 distances
# These coefficients account for model-specific characteristics in how embeddings change
# Source: TeaCache paper and ComfyUI-TeaCache empirical tuning
_MODEL_COEFFICIENTS
=
{
# FLUX transformer coefficients from TeaCache paper
"FluxTransformer2DModel"
:
[
4.98651651e02
,
-
2.83781631e02
,
5.58554382e01
,
-
3.82021401e00
,
2.64230861e-01
,
],
# Qwen-Image transformer coefficients from ComfyUI-TeaCache
# Tuned specifically for Qwen's dual-stream transformer architecture
# Used for all Qwen-Image Family pipelines, in general
"QwenImageTransformer2DModel"
:
[
-
4.50000000e02
,
2.80000000e02
,
-
4.50000000e01
,
3.20000000e00
,
-
2.00000000e-02
,
],
# Bagel transformer coefficients
# Using Qwen's coefficients as reasonable default given shared architecture
"Bagel"
:
[
1.33313129e06
,
-
1.68644226e05
,
7.95050740e03
,
-
1.63747873e02
,
1.26352397e00
],
# Z-Image transformer coefficients
# Copied from Qwen-Image, need to be tuned specifically for Z-Image in future
"ZImageTransformer2DModel"
:
[
-
4.50000000e02
,
2.80000000e02
,
-
4.50000000e01
,
3.20000000e00
,
-
2.00000000e-02
,
],
}
@
dataclass
class
TeaCacheConfig
:
"""
Configuration for TeaCache applied to transformer models.
TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up
diffusion model inference by reusing transformer block computations when consecutive
timestep embeddings are similar.
Args:
rel_l1_thresh: Threshold for accumulated relative L1 distance. When below threshold,
cached residual is reused. Values in [0.1, 0.3] work best:
- 0.2: ~1.5x speedup with minimal quality loss
- 0.4: ~1.8x speedup with slight quality loss
- 0.6: ~2.0x speedup with noticeable quality loss
coefficients: Polynomial coefficients for rescaling L1 distance. If None, uses
model-specific defaults based on transformer_type.
transformer_type: Transformer class name (e.g., "QwenImageTransformer2DModel").
Auto-detected from pipeline.transformer.__class__.__name__ in backend.
Defaults to "QwenImageTransformer2DModel".
"""
rel_l1_thresh
:
float
=
0.2
coefficients
:
list
[
float
]
|
None
=
None
transformer_type
:
str
=
"QwenImageTransformer2DModel"
def
__post_init__
(
self
)
->
None
:
"""Validate and set default coefficients."""
if
self
.
rel_l1_thresh
<=
0
:
raise
ValueError
(
f
"rel_l1_thresh must be positive, got
{
self
.
rel_l1_thresh
}
"
)
if
self
.
coefficients
is
None
:
# Use model-specific coefficients, explicitly check if the type exists or not
if
self
.
transformer_type
not
in
_MODEL_COEFFICIENTS
:
raise
KeyError
(
f
"Cannot find coefficients for
{
self
.
transformer_type
}
. "
f
"Supported:
{
list
(
_MODEL_COEFFICIENTS
.
keys
())
}
"
)
self
.
coefficients
=
_MODEL_COEFFICIENTS
[
self
.
transformer_type
]
if
len
(
self
.
coefficients
)
!=
5
:
raise
ValueError
(
f
"coefficients must contain exactly 5 elements, got
{
len
(
self
.
coefficients
)
}
"
)
vllm_omni/diffusion/cache/teacache/extractors.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Model-specific extractors for TeaCache.
This module provides a registry of extractor functions that know how to extract
modulated inputs from different transformer architectures. Adding support for
a new model requires only adding a new extractor function to the registry.
With Option B enhancement, extractors now return a CacheContext object containing
all model-specific information needed for generic caching, including preprocessing,
transformer execution, and postprocessing logic.
"""
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
typing
import
Any
import
torch
import
torch.nn
as
nn
from
vllm_omni.diffusion.forward_context
import
get_forward_context
@
dataclass
class
CacheContext
:
"""
Context object containing all model-specific information for caching.
This allows the TeaCacheHook to remain completely generic - all model-specific
logic is encapsulated in the extractor that returns this context.
Attributes:
modulated_input: Tensor used for cache decision (similarity comparison).
Must be a torch.Tensor extracted from the first transformer block,
typically after applying normalization and modulation.
hidden_states: Current hidden states (will be modified by caching).
Must be a torch.Tensor representing the main image/latent states
after preprocessing but before transformer blocks.
encoder_hidden_states: Optional encoder states (for dual-stream models).
Set to None for single-stream models (e.g., Flux).
For dual-stream models (e.g., Qwen), contains text encoder outputs.
temb: Timestep embedding tensor.
Must be a torch.Tensor containing the timestep conditioning.
run_transformer_blocks: Callable that executes model-specific transformer blocks.
Signature: () -> tuple[torch.Tensor, ...]
Returns:
tuple containing:
- [0]: processed hidden_states (required)
- [1]: processed encoder_hidden_states (optional, only for dual-stream)
Example for single-stream:
def run_blocks():
h = hidden_states
for block in module.transformer_blocks:
h = block(h, temb=temb)
return (h,)
Example for dual-stream:
def run_blocks():
h, e = hidden_states, encoder_hidden_states
for block in module.transformer_blocks:
e, h = block(h, e, temb=temb)
return (h, e)
postprocess: Callable that does model-specific output postprocessing.
Signature: (torch.Tensor) -> Union[torch.Tensor, Transformer2DModelOutput, tuple]
Takes the processed hidden_states and applies final transformations
(normalization, projection) to produce the model output.
Example:
def postprocess(h):
h = module.norm_out(h, temb)
output = module.proj_out(h)
return Transformer2DModelOutput(sample=output)
extra_states: Optional dict for additional model-specific state.
Use this for models that need to pass additional context beyond
the standard fields.
"""
modulated_input
:
torch
.
Tensor
hidden_states
:
torch
.
Tensor
encoder_hidden_states
:
torch
.
Tensor
|
None
temb
:
torch
.
Tensor
run_transformer_blocks
:
Callable
[[],
tuple
[
torch
.
Tensor
,
...]]
postprocess
:
Callable
[[
torch
.
Tensor
],
Any
]
extra_states
:
dict
[
str
,
Any
]
|
None
=
None
def
validate
(
self
)
->
None
:
"""
Validate that the CacheContext contains valid data.
Raises:
TypeError: If fields have wrong types
ValueError: If tensors have invalid properties
RuntimeError: If callables fail basic invocation tests
This method should be called after creating a CacheContext to catch
common developer errors early with clear error messages.
"""
# Validate tensor fields
if
not
isinstance
(
self
.
modulated_input
,
torch
.
Tensor
):
raise
TypeError
(
f
"modulated_input must be torch.Tensor, got
{
type
(
self
.
modulated_input
)
}
"
)
if
not
isinstance
(
self
.
hidden_states
,
torch
.
Tensor
):
raise
TypeError
(
f
"hidden_states must be torch.Tensor, got
{
type
(
self
.
hidden_states
)
}
"
)
if
self
.
encoder_hidden_states
is
not
None
and
not
isinstance
(
self
.
encoder_hidden_states
,
torch
.
Tensor
):
raise
TypeError
(
f
"encoder_hidden_states must be torch.Tensor or None, got
{
type
(
self
.
encoder_hidden_states
)
}
"
)
if
not
isinstance
(
self
.
temb
,
torch
.
Tensor
):
raise
TypeError
(
f
"temb must be torch.Tensor, got
{
type
(
self
.
temb
)
}
"
)
# Validate callables
if
not
callable
(
self
.
run_transformer_blocks
):
raise
TypeError
(
f
"run_transformer_blocks must be callable, got
{
type
(
self
.
run_transformer_blocks
)
}
"
)
if
not
callable
(
self
.
postprocess
):
raise
TypeError
(
f
"postprocess must be callable, got
{
type
(
self
.
postprocess
)
}
"
)
# Validate tensor shapes are compatible
if
self
.
modulated_input
.
shape
[
0
]
!=
self
.
hidden_states
.
shape
[
0
]:
raise
ValueError
(
f
"Batch size mismatch: modulated_input has batch size "
f
"
{
self
.
modulated_input
.
shape
[
0
]
}
, but hidden_states has "
f
"
{
self
.
hidden_states
.
shape
[
0
]
}
"
)
# Validate devices match
if
self
.
modulated_input
.
device
!=
self
.
hidden_states
.
device
:
raise
ValueError
(
f
"Device mismatch: modulated_input on
{
self
.
modulated_input
.
device
}
, "
f
"hidden_states on
{
self
.
hidden_states
.
device
}
"
)
def extract_qwen_context(
    module: nn.Module,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_mask: torch.Tensor,
    timestep: torch.Tensor | float | int,
    img_shapes: torch.Tensor,
    txt_seq_lens: torch.Tensor,
    guidance: torch.Tensor | None = None,
    additional_t_cond: torch.Tensor | None = None,
    attention_kwargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> CacheContext:
    """
    Extract cache context for QwenImageTransformer2DModel.

    This is the ONLY Qwen-specific code needed for TeaCache support.
    It encapsulates preprocessing, modulated input extraction, transformer execution,
    and postprocessing logic.

    Args:
        module: QwenImageTransformer2DModel instance
        hidden_states: Input hidden states tensor
        encoder_hidden_states: Text encoder outputs
        encoder_hidden_states_mask: Mask for text encoder
        timestep: Current diffusion timestep
        img_shapes: Image shapes for position embedding
        txt_seq_lens: Text sequence lengths
        guidance: Optional guidance scale for CFG
        additional_t_cond: Optional additional timestep conditioning
        attention_kwargs: Additional attention arguments
        **kwargs: Additional keyword arguments ignored by this extractor
            (except "return_dict", which controls the postprocess output shape)

    Returns:
        CacheContext with all information needed for generic caching

    Raises:
        ValueError: If ``module`` has no non-empty ``transformer_blocks``.
    """
    # Imported lazily so this module does not hard-depend on diffusers at import time.
    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
        raise ValueError("Module must have transformer_blocks")

    # ============================================================================
    # PREPROCESSING (Qwen-specific)
    # ============================================================================
    # NOTE(review): `timestep.to(...)` assumes timestep arrives as a tensor despite
    # the float/int annotation — confirm callers always pass a tensor.
    hidden_states = module.img_in(hidden_states)
    timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype)
    encoder_hidden_states = module.txt_norm(encoder_hidden_states)
    encoder_hidden_states = module.txt_in(encoder_hidden_states)
    if guidance is not None:
        # Scale guidance by 1000 before embedding; presumably matches the
        # guidance-embedding convention of the upstream Qwen model — verify.
        guidance = guidance.to(hidden_states.dtype) * 1000
    # The embedder signature differs depending on whether guidance is supplied.
    temb = (
        module.time_text_embed(timestep, hidden_states, additional_t_cond)
        if guidance is None
        else module.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
    )
    image_rotary_emb = module.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

    # ============================================================================
    # EXTRACT MODULATED INPUT (for cache decision)
    # ============================================================================
    # Only the first block's image-branch modulation is needed: its normalized,
    # modulated input is the signal TeaCache compares across timesteps.
    block = module.transformer_blocks[0]
    img_mod_params = block.img_mod(temb)
    img_mod1, _ = img_mod_params.chunk(2, dim=-1)
    img_modulated, _ = block.img_norm1(hidden_states, img_mod1)

    # ============================================================================
    # DEFINE TRANSFORMER EXECUTION (Qwen-specific)
    # ============================================================================
    def run_transformer_blocks():
        """Execute all Qwen transformer blocks (closure over preprocessed state)."""
        h = hidden_states
        e = encoder_hidden_states
        encoder_mask = encoder_hidden_states_mask
        hidden_states_mask = None  # default
        # Under sequence parallelism the sequence may be padded; build a mask that
        # marks the padded tail as invalid so attention ignores it.
        if module.parallel_config is not None and module.parallel_config.sequence_parallel_size > 1:
            ctx = get_forward_context()
            if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
                # Create mask for the full (padded) sequence
                # valid positions = True, padding positions = False
                batch_size = hidden_states.shape[0]
                padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
                hidden_states_mask = torch.ones(
                    batch_size,
                    padded_seq_len,
                    dtype=torch.bool,
                    device=hidden_states.device,
                )
                hidden_states_mask[:, ctx.sp_original_seq_len:] = False
        # if mask is all true, set it to None (lets blocks skip masked attention)
        if hidden_states_mask is not None and hidden_states_mask.all():
            hidden_states_mask = None
        if encoder_mask is not None and encoder_mask.all():
            encoder_mask = None
        # Dual-stream: each block returns (encoder_states, image_states).
        for block in module.transformer_blocks:
            e, h = block(
                hidden_states=h,
                encoder_hidden_states=e,
                encoder_hidden_states_mask=encoder_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=attention_kwargs,
                hidden_states_mask=hidden_states_mask,
            )
        return (h, e)

    # ============================================================================
    # DEFINE POSTPROCESSING (Qwen-specific)
    # ============================================================================
    # Honor the caller's diffusers-style return_dict flag (defaults to True).
    return_dict = kwargs.get("return_dict", True)

    def postprocess(h):
        """Apply Qwen-specific output postprocessing (norm + projection)."""
        h = module.norm_out(h, temb)
        output = module.proj_out(h)
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    # ============================================================================
    # RETURN CONTEXT
    # ============================================================================
    return CacheContext(
        modulated_input=img_modulated,
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
    )
def extract_bagel_context(
    module: nn.Module,
    x_t: torch.Tensor,
    timestep: torch.Tensor | float | int,
    packed_vae_token_indexes: torch.LongTensor,
    packed_vae_position_ids: torch.LongTensor,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_indexes: torch.LongTensor,
    packed_position_ids: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    key_values_lens: torch.IntTensor,
    past_key_values: Any,
    packed_key_value_indexes: torch.LongTensor,
    **kwargs: Any,
) -> CacheContext:
    """
    Extract cache context for Bagel model.

    Args:
        module: Bagel instance
        x_t: Latent image input
        timestep: Current timestep
        packed_vae_token_indexes: Indexes for VAE tokens in packed sequence
        packed_vae_position_ids: Position IDs for VAE tokens
        packed_text_ids: Text token IDs
        packed_text_indexes: Indexes for text tokens in packed sequence
        packed_indexes: Global indexes
        packed_position_ids: Global position IDs
        packed_seqlens: Sequence lengths
        key_values_lens: KV cache lengths
        past_key_values: KV cache
        packed_key_value_indexes: KV cache indexes
        **kwargs: Additional keyword arguments (ignored)

    Returns:
        CacheContext with all information needed for generic caching
    """
    # 1. Embed text tokens and scatter them into a zero-initialized packed buffer
    #    sized for the whole (text + VAE) sequence.
    packed_text_embedding = module.language_model.model.embed_tokens(packed_text_ids)
    packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), module.hidden_size))
    packed_sequence[packed_text_indexes] = packed_text_embedding

    # 2. Embed timestep — normalize scalars/0-d tensors to a 1-d tensor first.
    if not isinstance(timestep, torch.Tensor):
        timestep = torch.tensor([timestep], device=x_t.device)
    if timestep.dim() == 0:
        timestep = timestep.unsqueeze(0)

    # 3. Embed image (x_t): VAE latents projected to LLM space, plus timestep and
    #    positional embeddings, then scattered into the packed buffer.
    packed_pos_embed = module.latent_pos_embed(packed_vae_position_ids)
    packed_timestep_embeds = module.time_embedder(timestep)
    x_t_emb = module.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
    # Match dtype of the packed buffer before the in-place scatter.
    if x_t_emb.dtype != packed_sequence.dtype:
        x_t_emb = x_t_emb.to(packed_sequence.dtype)
    packed_sequence[packed_vae_token_indexes] = x_t_emb

    # Use the full packed sequence as modulated input to match hidden_states size
    modulated_input = packed_sequence

    def run_transformer_blocks():
        # MoE models need extra routing hints so text/VAE tokens hit the right experts.
        extra_inputs = {}
        if module.use_moe:
            extra_inputs = {
                "mode": "gen",
                "packed_vae_token_indexes": packed_vae_token_indexes,
                "packed_text_indexes": packed_text_indexes,
            }
        # Run the LLM backbone over the packed sequence without updating the KV cache
        # and with bidirectional (non-causal) attention.
        output = module.language_model.forward(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=packed_position_ids,
            packed_query_indexes=packed_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=False,
            is_causal=False,
            **extra_inputs,
        )
        return (output.packed_query_sequence,)

    def postprocess(h):
        # Project back to VAE latent space and keep only the VAE-token positions.
        v_t = module.llm2vae(h)
        v_t = v_t[packed_vae_token_indexes]
        return v_t

    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=packed_sequence,  # Use full packed sequence
        encoder_hidden_states=None,
        temb=packed_timestep_embeds,  # Approximate
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
    )
def extract_zimage_context(
    module: nn.Module,
    x: list[torch.Tensor],
    t: torch.Tensor,
    cap_feats: list[torch.Tensor],
    patch_size: int = 2,
    f_patch_size: int = 1,
    **kwargs: Any,
) -> CacheContext:
    """
    Extract cache context for ZImageTransformer2DModel.

    This is the ONLY Z-Image-specific code needed for TeaCache support.
    It encapsulates preprocessing, modulated input extraction, transformer execution,
    and postprocessing logic.

    Args:
        module: ZImageTransformer2DModel instance
        x: List of image tensors per batch item
        t: Timestep tensor
        cap_feats: List of caption feature tensors per batch item
        patch_size: Patch size for patchification (default: 2)
        f_patch_size: Frame patch size (default: 1)
        **kwargs: Additional keyword arguments ignored by this extractor

    Returns:
        CacheContext with all information needed for generic caching

    Raises:
        ValueError: If ``module`` has no non-empty main ``layers``.
    """
    from torch.nn.utils.rnn import pad_sequence

    if not hasattr(module, "layers") or len(module.layers) == 0:
        raise ValueError("Module must have main transformer layers")

    bsz = len(x)
    device = x[0].device

    # ============================================================================
    # PREPROCESSING (Z-Image specific)
    # ============================================================================
    # Scale timestep and create timestep embedding
    t_scaled = t * module.t_scale
    adaln_input = module.t_embedder(t_scaled)

    # Patchify and embed inputs (per-item variable-length sequences)
    (
        x_patches,
        cap_feats_processed,
        x_size,
        x_pos_ids,
        cap_pos_ids,
        x_inner_pad_mask,
        cap_inner_pad_mask,
    ) = module.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

    # Process image patches through embedder and noise refiner
    x_item_seqlens = [len(_) for _ in x_patches]
    x_max_item_seqlen = max(x_item_seqlens)
    x_embedded = torch.cat(x_patches, dim=0)
    # Embedders are keyed by "<patch_size>-<f_patch_size>" strings.
    x_embedded = module.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_embedded)
    # Match adaln_input dtype to x_embedded
    adaln_input = adaln_input.type_as(x_embedded)
    # Apply pad token at inner-pad positions
    x_embedded[torch.cat(x_inner_pad_mask)] = module.x_pad_token
    x_list = list(x_embedded.split(x_item_seqlens, dim=0))

    # Compute rope embeddings for image patches
    x_cos, x_sin = module.rope_embedder(torch.cat(x_pos_ids, dim=0))
    x_cos = list(x_cos.split(x_item_seqlens, dim=0))
    x_sin = list(x_sin.split(x_item_seqlens, dim=0))

    # Pad sequences for batch processing
    x_batched = pad_sequence(x_list, batch_first=True, padding_value=0.0)
    x_cos_batched = pad_sequence(x_cos, batch_first=True, padding_value=0.0)
    x_sin_batched = pad_sequence(x_sin, batch_first=True, padding_value=0.0)
    x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(x_item_seqlens):
        x_attn_mask[i, :seq_len] = 1

    # Run noise refiner blocks
    for layer in module.noise_refiner:
        x_batched = layer(x_batched, x_attn_mask, x_cos_batched, x_sin_batched, adaln_input)

    # Process caption features through embedder and context refiner
    cap_item_seqlens = [len(_) for _ in cap_feats_processed]
    cap_max_item_seqlen = max(cap_item_seqlens)
    cap_embedded = torch.cat(cap_feats_processed, dim=0)
    cap_embedded = module.cap_embedder(cap_embedded)
    cap_embedded[torch.cat(cap_inner_pad_mask)] = module.cap_pad_token
    cap_list = list(cap_embedded.split(cap_item_seqlens, dim=0))

    # Compute rope embeddings for caption
    cap_cos, cap_sin = module.rope_embedder(torch.cat(cap_pos_ids, dim=0))
    cap_cos = list(cap_cos.split(cap_item_seqlens, dim=0))
    cap_sin = list(cap_sin.split(cap_item_seqlens, dim=0))

    # Pad sequences for batch processing
    cap_batched = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
    cap_cos_batched = pad_sequence(cap_cos, batch_first=True, padding_value=0.0)
    cap_sin_batched = pad_sequence(cap_sin, batch_first=True, padding_value=0.0)
    cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(cap_item_seqlens):
        cap_attn_mask[i, :seq_len] = 1

    # Run context refiner blocks (no adaln conditioning here, unlike noise refiner)
    for layer in module.context_refiner:
        cap_batched = layer(cap_batched, cap_attn_mask, cap_cos_batched, cap_sin_batched)

    # Create unified sequence (image + caption), per item: image tokens first,
    # then caption tokens, truncated back to each item's true lengths.
    unified_list = []
    unified_cos_list = []
    unified_sin_list = []
    for i in range(bsz):
        x_len = x_item_seqlens[i]
        cap_len = cap_item_seqlens[i]
        unified_list.append(torch.cat([x_batched[i][:x_len], cap_batched[i][:cap_len]]))
        unified_cos_list.append(torch.cat([x_cos_batched[i][:x_len], cap_cos_batched[i][:cap_len]]))
        unified_sin_list.append(torch.cat([x_sin_batched[i][:x_len], cap_sin_batched[i][:cap_len]]))
    unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
    unified_max_item_seqlen = max(unified_item_seqlens)
    unified = pad_sequence(unified_list, batch_first=True, padding_value=0.0)
    unified_cos = pad_sequence(unified_cos_list, batch_first=True, padding_value=0.0)
    unified_sin = pad_sequence(unified_sin_list, batch_first=True, padding_value=0.0)
    unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(unified_item_seqlens):
        unified_attn_mask[i, :seq_len] = 1

    # ============================================================================
    # EXTRACT MODULATED INPUT (for cache decision)
    # ============================================================================
    # Use the first main transformer block's modulation
    # The main layers have modulation=True and process the unified sequence
    block = module.layers[0]
    # Get modulation parameters: scale_msa, gate_msa, scale_mlp, gate_mlp
    mod_params = block.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
    scale_msa = 1.0 + mod_params[0]
    # Extract modulated input: normalized hidden states scaled by modulation
    modulated_input = block.attention_norm1(unified) * scale_msa

    # ============================================================================
    # DEFINE TRANSFORMER EXECUTION (Z-Image specific)
    # ============================================================================
    def run_transformer_blocks():
        """Execute all Z-Image main transformer blocks."""
        h = unified
        for layer in module.layers:
            h = layer(h, unified_attn_mask, unified_cos, unified_sin, adaln_input)
        return (h,)

    # ============================================================================
    # DEFINE POSTPROCESSING (Z-Image specific)
    # ============================================================================
    def postprocess(h):
        """Apply Z-Image specific output postprocessing (final layer + unpatchify)."""
        h = module.all_final_layer[f"{patch_size}-{f_patch_size}"](h, adaln_input)
        h = list(h.unbind(dim=0))
        output = module.unpatchify(h, x_size, patch_size, f_patch_size)
        # Returns (output, {}) — second element presumably reserved for aux outputs;
        # confirm against the pipeline's expected transformer return shape.
        return output, {}

    # ============================================================================
    # RETURN CONTEXT
    # ============================================================================
    return CacheContext(
        modulated_input=modulated_input,
        hidden_states=unified,
        encoder_hidden_states=None,  # Z-Image uses unified sequence, no separate encoder states
        temb=adaln_input,
        run_transformer_blocks=run_transformer_blocks,
        postprocess=postprocess,
        extra_states={
            "unified_attn_mask": unified_attn_mask,
            "unified_cos": unified_cos,
            "unified_sin": unified_sin,
            "x_size": x_size,
            "x_item_seqlens": x_item_seqlens,
            "patch_size": patch_size,
            "f_patch_size": f_patch_size,
        },
    )
# Registry for model-specific extractors.
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
#
# Note: Use the transformer class name as specified in pipelines as TeaCache hooks operate
# on the transformer module and multiple pipelines can share the same transformer.
# Extend via register_extractor() rather than editing this dict from other modules.
EXTRACTOR_REGISTRY: dict[str, Callable] = {
    "QwenImageTransformer2DModel": extract_qwen_context,
    "Bagel": extract_bagel_context,
    "ZImageTransformer2DModel": extract_zimage_context,
    # Future models:
    # "FluxTransformer2DModel": extract_flux_context,
    # "CogVideoXTransformer3DModel": extract_cogvideox_context,
}
def register_extractor(transformer_cls_name: str, extractor_fn: Callable) -> None:
    """
    Register an extractor function for a transformer model type.

    Later registrations overwrite earlier ones for the same name, so this can
    extend TeaCache to new models — or replace a built-in extractor — without
    touching the core TeaCache code.

    Args:
        transformer_cls_name: Transformer model type identifier (class name or type string)
        extractor_fn: Function with signature (module, *args, **kwargs) -> CacheContext

    Example:
        >>> def extract_flux_context(module, hidden_states, timestep, guidance=None, **kwargs):
        ...     temb = module.time_text_embed(timestep, guidance)
        ...     modulated = module.transformer_blocks[0].norm1(hidden_states, emb=temb)
        ...     def run_blocks():
        ...         h = hidden_states
        ...         for block in module.transformer_blocks:
        ...             h = block(h, temb=temb)
        ...         return (h,)
        ...     def postprocess(h):
        ...         return module.proj_out(module.norm_out(h, temb))
        ...     return CacheContext(modulated, hidden_states, None, temb, run_blocks, postprocess)
        >>> register_extractor("FluxTransformer2DModel", extract_flux_context)
    """
    EXTRACTOR_REGISTRY[transformer_cls_name] = extractor_fn
def get_extractor(transformer_cls_name: str) -> Callable:
    """
    Look up the extractor function for a transformer class.

    The lookup is an exact string match against EXTRACTOR_REGISTRY keys; pass
    the transformer type exactly as the pipeline sees it, i.e.
    ``pipeline.transformer.__class__.__name__``. No substring matching is done.

    Args:
        transformer_cls_name: Transformer class name (e.g., "QwenImageTransformer2DModel").

    Returns:
        Extractor function with signature (module, *args, **kwargs) -> CacheContext

    Raises:
        ValueError: If the model type is not present in the registry.

    Example:
        >>> extractor = get_extractor("QwenImageTransformer2DModel")
        >>> ctx = extractor(transformer, hidden_states, encoder_hidden_states, timestep, ...)
    """
    # Unknown name: fail loudly with the list of supported types.
    if transformer_cls_name not in EXTRACTOR_REGISTRY:
        available_types = list(EXTRACTOR_REGISTRY.keys())
        raise ValueError(
            f"Unknown model type: '{transformer_cls_name}'. "
            f"Available types: {available_types}\n"
            f"To add support for a new model, use register_extractor() or add to EXTRACTOR_REGISTRY."
        )
    return EXTRACTOR_REGISTRY[transformer_cls_name]
vllm_omni/diffusion/cache/teacache/hook.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Hook-based TeaCache implementation for vLLM-Omni.
This module implements a diffusers-style hook system that completely intercepts
the transformer forward pass, eliminating the need for any TeaCache-specific
code in model definitions. Model developers only need to add an extractor function
to support new models.
"""
from
__future__
import
annotations
from
typing
import
Any
import
numpy
as
np
import
torch
from
vllm_omni.diffusion.cache.teacache.config
import
TeaCacheConfig
from
vllm_omni.diffusion.cache.teacache.extractors
import
get_extractor
from
vllm_omni.diffusion.cache.teacache.state
import
TeaCacheState
from
vllm_omni.diffusion.distributed.parallel_state
import
(
get_classifier_free_guidance_rank
,
get_classifier_free_guidance_world_size
,
)
from
vllm_omni.diffusion.hooks
import
HookRegistry
,
ModelHook
,
StateManager
class
TeaCacheHook
(
ModelHook
):
"""
ModelHook implementing TeaCache for transformer models.
This hook completely intercepts the transformer's forward pass and implements
adaptive caching based on timestep embedding similarity. It's model-agnostic
and supports multiple model types through extractor functions.
Key features:
- Zero changes to model code
- CFG-aware with separate states for positive/negative branches
- CFG-parallel compatible: properly detects branch identity across ranks
- Model-specific polynomial rescaling
- Auto-detection of model types
Attributes:
config: TeaCache configuration with thresholds and callbacks
rescale_func: Polynomial function for rescaling L1 distances
state_manager: Manages TeaCacheState across forward passes
extractor_fn: Model-specific function to extract modulated input
"""
_HOOK_NAME
=
"teacache"
def
__init__
(
self
,
config
:
TeaCacheConfig
):
"""
Initialize TeaCacheHook.
Args:
config: TeaCache configuration object.
"""
super
().
__init__
()
self
.
config
=
config
self
.
rescale_func
=
np
.
poly1d
(
config
.
coefficients
)
self
.
state_manager
=
StateManager
(
TeaCacheState
)
self
.
extractor_fn
=
None
self
.
_forward_cnt
=
0
def
initialize_hook
(
self
,
module
:
torch
.
nn
.
Module
)
->
torch
.
nn
.
Module
:
"""
Initialize hook with extractor from config transformer model type.
Args:
module: The module to initialize the hook for.
Returns:
The initialized module.
"""
# Get extractor function based on transformer_type from config
# transformer_type is the transformer class name (e.g., "QwenImageTransformer2DModel")
self
.
extractor_fn
=
get_extractor
(
self
.
config
.
transformer_type
)
# Set default context
self
.
state_manager
.
set_context
(
"teacache"
)
return
module
def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
    """
    Generic forward handler that works for ANY model.

    This method is completely model-agnostic. All model-specific logic
    is encapsulated in the extractor function that returns a CacheContext.

    The extractor does:
    - Model-specific preprocessing
    - Extraction of modulated input for cache decision
    - Providing transformer execution callable
    - Providing postprocessing callable

    This hook does:
    - CFG-aware state management
    - Cache decision logic (generic)
    - Residual caching and reuse

    Args:
        module: Transformer module (any architecture)
        *args: Positional arguments for model forward
        **kwargs: Keyword arguments for model forward

    Returns:
        Model output (format depends on model)
    """
    # Get model-specific context from extractor.
    # The extractor encapsulates ALL model-specific logic.
    ctx = self.extractor_fn(module, *args, **kwargs)

    # ============================================================================
    # GENERIC CACHING LOGIC (works for all models)
    # ============================================================================
    # Set context based on CFG branch so each branch accumulates its own state.
    # With CFG-parallel, each rank processes only one branch:
    #   - cfg_rank 0:   positive branch
    #   - cfg_rank > 0: negative branch
    # Without CFG-parallel, branches alternate within a single rank.
    if getattr(module, "do_true_cfg", False):
        cfg_parallel_size = get_classifier_free_guidance_world_size()
        if cfg_parallel_size > 1:
            cfg_rank = get_classifier_free_guidance_rank()
            cache_branch = "negative" if cfg_rank > 0 else "positive"
        else:
            # No CFG-parallel: odd forward passes are assumed to be the
            # negative branch (positive/negative alternate per call).
            cache_branch = "negative" if self._forward_cnt % 2 == 1 else "positive"
    else:
        cache_branch = "positive"

    context_name = f"teacache_{cache_branch}"
    self.state_manager.set_context(context_name)
    state = self.state_manager.get_state()

    # Decide whether to compute or cache based on modulated-input similarity.
    should_compute = self._should_compute_full_transformer(state, ctx.modulated_input)

    if not should_compute and state.previous_residual is not None:
        # ============================================================================
        # FAST PATH: Reuse cached residuals
        # ============================================================================
        # Approximate this step's output as input + residual from the last
        # fully-computed step.
        ctx.hidden_states = ctx.hidden_states + state.previous_residual
        if state.previous_residual_encoder is not None and ctx.encoder_hidden_states is not None:
            ctx.encoder_hidden_states = ctx.encoder_hidden_states + state.previous_residual_encoder
        output = ctx.hidden_states
    else:
        # ============================================================================
        # SLOW PATH: Full transformer computation
        # ============================================================================
        # Snapshot the inputs so the residual (output - input) can be cached.
        ori_hidden_states = ctx.hidden_states.clone()
        ori_encoder_hidden_states = (
            ctx.encoder_hidden_states.clone() if ctx.encoder_hidden_states is not None else None
        )
        # Run transformer blocks using the model-specific callable.
        outputs = ctx.run_transformer_blocks()
        # Update context with outputs. outputs[1], when present, is the
        # encoder stream of dual-stream transformers.
        ctx.hidden_states = outputs[0]
        if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
            ctx.encoder_hidden_states = outputs[1]
        # Cache residuals for subsequent timesteps; detach so cached tensors
        # do not keep the autograd graph alive.
        state.previous_residual = (ctx.hidden_states - ori_hidden_states).detach()
        if ori_encoder_hidden_states is not None:
            state.previous_residual_encoder = (ctx.encoder_hidden_states - ori_encoder_hidden_states).detach()
        output = ctx.hidden_states

    # Update state: remember this step's modulated input for the next
    # similarity comparison, and advance both counters.
    state.previous_modulated_input = ctx.modulated_input.detach()
    state.cnt += 1
    self._forward_cnt += 1

    # ============================================================================
    # POSTPROCESSING (model-specific, via callable)
    # ============================================================================
    return ctx.postprocess(output)
def _should_compute_full_transformer(self, state: "TeaCacheState", modulated_inp: torch.Tensor) -> bool:
    """Decide between a full transformer pass and reusing the cached residual.

    Core TeaCache algorithm:
    1. The first timestep is always computed.
    2. Afterwards, the relative L1 distance between the current and previous
       modulated inputs is rescaled with the model-specific polynomial and
       accumulated; staying under the threshold means the cache is reused,
       crossing it forces a recompute and resets the accumulator.

    Args:
        state: Current TeaCacheState with counters and cached values.
        modulated_inp: Modulated input extracted from the first block.

    Returns:
        True to run the full transformer, False to reuse the cached residual.
    """
    # The very first timestep has nothing to compare against: compute.
    if state.cnt == 0:
        state.accumulated_rel_l1_distance = 0.0
        return True

    prev = state.previous_modulated_input
    if prev is None:
        # No reference input cached yet: compute.
        return True

    # Relative L1 change between consecutive modulated inputs
    # (epsilon guards against a zero denominator).
    numerator = (modulated_inp - prev).abs().mean()
    denominator = prev.abs().mean() + 1e-8
    raw_distance = (numerator / denominator).cpu().item()

    # Rescale with the model-specific polynomial and accumulate.
    state.accumulated_rel_l1_distance += abs(float(self.rescale_func(raw_distance)))

    if state.accumulated_rel_l1_distance >= self.config.rel_l1_thresh:
        # Accumulated drift is too large: reset and recompute.
        state.accumulated_rel_l1_distance = 0.0
        return True
    # Still within tolerance: reuse the cached residual.
    return False
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
    """Clear all cached TeaCache state ahead of a new inference run.

    Args:
        module: The module whose hook state is being reset.

    Returns:
        The same module, with state reset.
    """
    self.state_manager.reset()
    self._forward_cnt = 0
    return module
def apply_teacache_hook(module: torch.nn.Module, config: "TeaCacheConfig") -> None:
    """Attach TeaCache optimization to a transformer module.

    Registers a TeaCacheHook that fully intercepts the module's forward
    pass, adding adaptive residual caching without touching model code.

    Args:
        module: Transformer model to optimize (e.g. QwenImageTransformer2DModel).
        config: TeaCacheConfig specifying the caching parameters.

    Example:
        >>> config = TeaCacheConfig(
        ...     rel_l1_thresh=0.2,
        ...     transformer_type="QwenImageTransformer2DModel"
        ... )
        >>> apply_teacache_hook(transformer, config)
        >>> # The pipeline's transformer now uses TeaCache automatically.
    """
    hook = TeaCacheHook(config)
    registry = HookRegistry.get_or_create(module)
    registry.register_hook(TeaCacheHook._HOOK_NAME, hook)
vllm_omni/diffusion/cache/teacache/state.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
TeaCache state management.
This module manages the state for TeaCache hooks across diffusion timesteps.
"""
import
torch
class TeaCacheState:
    """
    Mutable per-context state for the TeaCache hook.

    Holds the timestep counter, the accumulated rescaled L1 distance, and
    the residuals cached from the most recent full transformer pass.
    """

    def __init__(self) -> None:
        """Start from a clean state (delegates to reset)."""
        self.reset()

    def reset(self) -> None:
        """Return every field to its initial value for a fresh inference run."""
        # Timestep tracking
        self.cnt = 0
        # Caching state
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input: torch.Tensor | None = None
        self.previous_residual: torch.Tensor | None = None
        self.previous_residual_encoder: torch.Tensor | None = None
vllm_omni/diffusion/compile.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
regionally_compile
(
model
:
nn
.
Module
,
*
compile_args
:
Any
,
**
compile_kwargs
:
Any
)
->
nn
.
Module
:
"""
Apply regional compilation to a PyTorch model.
Args:
model: The PyTorch model instance to compile
*compile_args: Positional arguments forwarded to torch.compile
**compile_kwargs: Keyword arguments forwarded to torch.compile
Returns:
The same model instance (modified in-place)
"""
# Get the list of repeated blocks from the model
repeated_blocks
=
getattr
(
model
,
"_repeated_blocks"
,
None
)
if
not
repeated_blocks
:
logger
.
warning
(
"Regional compilation skipped because the model does not define `_repeated_blocks`."
)
return
model
# Check if we have modules with the specified class names
has_compiled_region
=
False
for
submod
in
model
.
modules
():
if
submod
.
__class__
.
__name__
in
repeated_blocks
:
# Compile this submodule
submod
.
compile
(
*
compile_args
,
**
compile_kwargs
)
has_compiled_region
=
True
if
not
has_compiled_region
:
logger
.
warning
(
f
"Regional compilation skipped because
{
repeated_blocks
}
classes are not found in the model."
)
return
model
vllm_omni/diffusion/data.py
0 → 100644
View file @
c1cacde6
# adapted from sglang and fastvideo
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
import
os
import
random
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
Any
import
torch
from
pydantic
import
model_validator
from
typing_extensions
import
Self
from
vllm.config.utils
import
config
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.utils.network_utils
import
is_port_available
logger
=
init_logger
(
__name__
)
@config
@dataclass
class DiffusionParallelConfig:
    """Configuration for diffusion model distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel stages."""
    data_parallel_size: int = 1
    """Number of data parallel groups."""
    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""
    sequence_parallel_size: int | None = None
    """Number of sequence parallel groups. sequence_parallel_size = ring_degree * ulysses_degree"""
    ulysses_degree: int = 1
    """Number of GPUs used for ulysses sequence parallelism."""
    ring_degree: int = 1
    """Number of GPUs used for ring sequence parallelism."""
    cfg_parallel_size: int = 1
    """Number of Classifier Free Guidance (CFG) parallel groups."""

    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Validates the config relationships among the parallel strategies."""
        assert self.pipeline_parallel_size > 0, "Pipeline parallel size must be > 0"
        assert self.data_parallel_size > 0, "Data parallel size must be > 0"
        assert self.tensor_parallel_size > 0, "Tensor parallel size must be > 0"
        # NOTE(review): if sequence_parallel_size is still None when this
        # validator runs (i.e. before __post_init__ has filled it in), the
        # `> 0` comparison raises TypeError rather than AssertionError —
        # verify the pydantic validator / __post_init__ ordering.
        assert self.sequence_parallel_size > 0, "Sequence parallel size must be > 0"
        assert self.ulysses_degree > 0, "Ulysses degree must be > 0"
        assert self.ring_degree > 0, "Ring degree must be > 0"
        assert self.cfg_parallel_size > 0, "CFG parallel size must be > 0"
        assert self.cfg_parallel_size in [
            1,
            2,
        ], f"CFG parallel size must be 1 or 2, but got {self.cfg_parallel_size}"
        assert self.sequence_parallel_size == self.ulysses_degree * self.ring_degree, (
            "Sequence parallel size must be equal to the product of ulysses degree and ring degree,"
            f" but got {self.sequence_parallel_size} != {self.ulysses_degree} * {self.ring_degree}"
        )
        return self

    def __post_init__(self) -> None:
        """Derive sequence_parallel_size (if unset) and the total world_size."""
        if self.sequence_parallel_size is None:
            self.sequence_parallel_size = self.ulysses_degree * self.ring_degree
        # world_size is the product of every parallel dimension; note that
        # sequence parallelism enters via ulysses_degree * ring_degree.
        self.world_size = (
            self.pipeline_parallel_size
            * self.data_parallel_size
            * self.tensor_parallel_size
            * self.ulysses_degree
            * self.ring_degree
            * self.cfg_parallel_size
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DiffusionParallelConfig":
        """
        Create DiffusionParallelConfig from a dictionary.

        Args:
            data: Dictionary containing parallel configuration parameters

        Returns:
            DiffusionParallelConfig instance with parameters set from dict

        Raises:
            TypeError: If `data` is not a dict.
        """
        if not isinstance(data, dict):
            raise TypeError(f"Expected parallel config dict, got {type(data)!r}")
        return cls(**data)
@
dataclass
class
TransformerConfig
:
"""Container for raw transformer configuration dictionaries."""
params
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
@
classmethod
def
from_dict
(
cls
,
data
:
dict
[
str
,
Any
])
->
"TransformerConfig"
:
if
not
isinstance
(
data
,
dict
):
raise
TypeError
(
f
"Expected transformer config dict, got
{
type
(
data
)
!
r
}
"
)
return
cls
(
params
=
dict
(
data
))
def
to_dict
(
self
)
->
dict
[
str
,
Any
]:
return
dict
(
self
.
params
)
def
get
(
self
,
key
:
str
,
default
:
Any
|
None
=
None
)
->
Any
:
return
self
.
params
.
get
(
key
,
default
)
def
__getattr__
(
self
,
item
:
str
)
->
Any
:
params
=
object
.
__getattribute__
(
self
,
"params"
)
try
:
return
params
[
item
]
except
KeyError
as
exc
:
raise
AttributeError
(
item
)
from
exc
@
dataclass
class
DiffusionCacheConfig
:
"""
Configuration for cache adapters (TeaCache, cache-dit, etc.).
This dataclass provides a unified interface for cache configuration parameters.
It can be initialized from a dictionary and accessed via attributes.
Common parameters:
- TeaCache: rel_l1_thresh, coefficients (optional)
- cache-dit: Fn_compute_blocks, Bn_compute_blocks, max_warmup_steps,
residual_diff_threshold, enable_taylorseer, taylorseer_order,
scm_steps_mask_policy, scm_steps_policy
Example:
>>> # From dict (user-facing API) - partial config uses defaults for missing keys
>>> config = DiffusionCacheConfig.from_dict({"rel_l1_thresh": 0.3})
>>> # Access via attribute
>>> print(config.rel_l1_thresh) # 0.3 (from dict)
>>> print(config.Fn_compute_blocks) # 8 (default)
>>> # Empty dict uses all defaults
>>> default_config = DiffusionCacheConfig.from_dict({})
>>> print(default_config.rel_l1_thresh) # 0.2 (default)
"""
# TeaCache parameters [tea_cache only]
# Default: 0.2 provides ~1.5x speedup with minimal quality loss (optimal balance)
rel_l1_thresh
:
float
=
0.2
coefficients
:
list
[
float
]
|
None
=
None
# Uses model-specific defaults if None
# cache-dit parameters [cache-dit only]
# Default: 1 forward compute block (optimized for single-transformer models)
# Use 1 as default instead of cache-dit's 8, optimized for single-transformer models
# This provides better performance while maintaining quality for most use cases
Fn_compute_blocks
:
int
=
1
# Default: 0 backward compute blocks (no fusion by default)
Bn_compute_blocks
:
int
=
0
# Default: 4 warmup steps (optimized for few-step distilled models like Z-Image with 8 steps)
# Use 4 as default warmup steps instead of 8 in cache-dit, making DBCache work
# for few-step distilled models (e.g., Z-Image with 8 steps)
max_warmup_steps
:
int
=
4
# Default: -1 (unlimited cached steps) - DBCache disables caching when previous cached steps exceed this value
# to prevent precision degradation. Set to -1 for unlimited caching (cache-dit default).
max_cached_steps
:
int
=
-
1
# Default: 0.24 residual difference threshold (higher for more aggressive caching)
# Use a relatively higher residual diff threshold (0.24) as default to allow more
# aggressive caching. This is safe because we have max_continuous_cached_steps limit.
# Without this limit, a lower threshold like 0.12 would be needed.
residual_diff_threshold
:
float
=
0.24
# Default: Limit consecutive cached steps to 3 to prevent precision degradation
# This allows us to use a higher residual_diff_threshold for more aggressive caching
max_continuous_cached_steps
:
int
=
3
# Default: Disable TaylorSeer (not suitable for few-step distilled models)
# TaylorSeer is not suitable for few-step distilled models, so we disable it by default.
# References:
# - From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
# - Forecast then Calibrate: Feature Caching as ODE for Efficient Diffusion Transformers
enable_taylorseer
:
bool
=
False
# Default: 1st order TaylorSeer polynomial
taylorseer_order
:
int
=
1
# Default: None SCM mask policy (disabled by default)
scm_steps_mask_policy
:
str
|
None
=
None
# Default: "dynamic" steps policy for adaptive caching
scm_steps_policy
:
str
=
"dynamic"
# Used by cache-dit for scm mask generation. If this value changes during inference,
# we will re-generate the scm mask and refresh the cache context.
num_inference_steps
:
int
|
None
=
None
# Additional parameters that may be passed but not explicitly defined
_extra_params
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
,
repr
=
False
)
@
classmethod
def
from_dict
(
cls
,
data
:
dict
[
str
,
Any
])
->
"DiffusionCacheConfig"
:
"""
Create DiffusionCacheConfig from a dictionary.
Args:
data: Dictionary containing cache configuration parameters
Returns:
DiffusionCacheConfig instance with parameters set from dict
"""
if
not
isinstance
(
data
,
dict
):
raise
TypeError
(
f
"Expected cache config dict, got
{
type
(
data
)
!
r
}
"
)
# Get all dataclass field names automatically
field_names
=
{
f
.
name
for
f
in
fields
(
cls
)}
# Extract parameters that match dataclass fields (excluding private fields)
known_params
=
{
k
:
v
for
k
,
v
in
data
.
items
()
if
k
in
field_names
and
not
k
.
startswith
(
"_"
)}
# Store extra parameters
extra_params
=
{
k
:
v
for
k
,
v
in
data
.
items
()
if
k
not
in
field_names
}
# Create instance with known params (missing ones will use defaults)
# Then update _extra_params after creation since it's a private field
instance
=
cls
(
**
known_params
,
_extra_params
=
extra_params
)
return
instance
def
__getattr__
(
self
,
item
:
str
)
->
Any
:
"""
Allow access to extra parameters via attribute access.
This enables accessing parameters that weren't explicitly defined
in the dataclass fields but were passed in the dict.
"""
if
item
==
"_extra_params"
or
item
.
startswith
(
"_"
):
return
object
.
__getattribute__
(
self
,
item
)
extra
=
object
.
__getattribute__
(
self
,
"_extra_params"
)
if
item
in
extra
:
return
extra
[
item
]
raise
AttributeError
(
f
"'
{
self
.
__class__
.
__name__
}
' object has no attribute '
{
item
}
'"
)
@dataclass
class OmniDiffusionConfig:
    """Top-level configuration for a diffusion inference deployment.

    Aggregates model selection, parallelism, caching, offloading, server
    endpoint, and LoRA settings. `__post_init__` normalizes nested configs
    and resolves ports, so field values may change after construction.
    """

    # Model and path configuration (for convenience)
    model: str | None = None
    model_class_name: str | None = None
    dtype: torch.dtype = torch.bfloat16
    tf_model_config: TransformerConfig = field(default_factory=TransformerConfig)
    # Attention
    attention_backend: str | None = None
    # Running mode
    # mode: ExecutionMode = ExecutionMode.INFERENCE
    # Workload type
    # workload_type: WorkloadType = WorkloadType.T2V
    # Cache strategy (legacy)
    cache_strategy: str = "none"
    parallel_config: DiffusionParallelConfig = field(default_factory=DiffusionParallelConfig)
    # Cache backend configuration (NEW)
    cache_backend: str = "none"  # "tea_cache", "deep_cache", etc.
    # May be passed as a plain dict; converted in __post_init__.
    cache_config: DiffusionCacheConfig | dict[str, Any] = field(default_factory=dict)
    enable_cache_dit_summary: bool = False
    # Distributed executor backend
    distributed_executor_backend: str = "mp"
    nccl_port: int | None = None
    # HuggingFace specific parameters
    trust_remote_code: bool = False
    revision: str | None = None
    # Defaults to parallel_config.world_size in __post_init__ when unset.
    num_gpus: int | None = None
    hsdp_replicate_dim: int = 1
    hsdp_shard_dim: int = -1
    dist_timeout: int | None = None  # timeout for torch.distributed
    # pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False)
    # LoRA parameters
    lora_path: str | None = None
    lora_scale: float = 1.0
    max_cpu_loras: int | None = None
    output_type: str = "pil"
    # CPU offload parameters
    # When enabled, DiT and encoders swap GPU access (mutual exclusion):
    # - Text encoders run on GPU while DiT is on CPU
    # - DiT runs on GPU while encoders are on CPU
    enable_cpu_offload: bool = False
    # Layer-wise offloading (block-level offloading) parameters
    enable_layerwise_offload: bool = False
    # Number of transformer blocks ready for computation to keep on GPU
    layerwise_num_gpu_layers: int = 1
    use_fsdp_inference: bool = False
    pin_cpu_memory: bool = True  # Use pinned memory for faster transfers when offloading
    # VAE memory optimization parameters
    vae_use_slicing: bool = False
    vae_use_tiling: bool = False
    # STA (Sliding Tile Attention) parameters
    mask_strategy_file_path: str | None = None
    # STA_mode: STA_Mode = STA_Mode.STA_INFERENCE
    skip_time_steps: int = 15
    # Compilation
    enforce_eager: bool = False
    # Enable sleep mode
    enable_sleep_mode: bool = False
    disable_autocast: bool = False
    # VSA parameters
    VSA_sparsity: float = 0.0  # inference/validation sparsity
    # V-MoBA parameters
    moba_config_path: str | None = None
    # moba_config: dict[str, Any] = field(default_factory=dict)
    # Master port for distributed inference
    # TODO: do not hard code
    master_port: int | None = None
    # http server endpoint config, would be ignored in local mode
    host: str | None = None
    port: int | None = None
    scheduler_port: int = 5555
    # Stage verification
    enable_stage_verification: bool = True
    # Prompt text file for batch processing
    prompt_file_path: str | None = None
    # model paths for correct deallocation
    model_paths: dict[str, str] = field(default_factory=dict)
    model_loaded: dict[str, bool] = field(
        default_factory=lambda: {
            "transformer": True,
            "vae": True,
        }
    )
    override_transformer_cls_name: str | None = None
    # # DMD parameters
    # dmd_denoising_steps: List[int] | None = field(default=None)
    # MoE parameters used by Wan2.2
    boundary_ratio: float | None = None
    # Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p)
    flow_shift: float | None = None
    # support multi images input
    supports_multimodal_inputs: bool = False
    # Logging
    log_level: str = "info"
    # Omni configuration (injected from stage config)
    omni_kv_config: dict[str, Any] = field(default_factory=dict)

    def settle_port(self, port: int, port_inc: int = 42, max_attempts: int = 100) -> int:
        """
        Find an available port with retry logic.

        Args:
            port: Initial port to check
            port_inc: Port increment for each attempt
            max_attempts: Maximum number of attempts to find an available port

        Returns:
            An available port number

        Raises:
            RuntimeError: If no available port is found after max_attempts
        """
        attempts = 0
        original_port = port
        while attempts < max_attempts:
            if is_port_available(port):
                if attempts > 0:
                    logger.info(f"Port {original_port} was unavailable, using port {port} instead")
                return port
            attempts += 1
            if port < 60000:
                port += port_inc
            else:
                # Wrap around with randomization to avoid collision
                port = 5000 + random.randint(0, 1000)
        raise RuntimeError(
            f"Failed to find available port after {max_attempts} attempts (started from port {original_port})"
        )

    def __post_init__(self) -> None:
        """Normalize nested configs, resolve the master port, and validate.

        Runs in a fixed order: port resolution, parallel-config conversion
        (needed before world_size is read), num_gpus defaulting/validation,
        dtype normalization, cache-config conversion, LoRA validation.
        """
        # TODO: remove hard code
        # NOTE(review): a random offset is added even when master_port was
        # given explicitly — presumably to avoid cross-run collisions; confirm.
        initial_master_port = (self.master_port or 30005) + random.randint(0, 100)
        self.master_port = self.settle_port(initial_master_port, 37)
        # Convert parallel_config dict to DiffusionParallelConfig if needed
        # This must be done before accessing parallel_config.world_size
        if isinstance(self.parallel_config, dict):
            self.parallel_config = DiffusionParallelConfig.from_dict(self.parallel_config)
        elif not isinstance(self.parallel_config, DiffusionParallelConfig):
            # If it's neither dict nor DiffusionParallelConfig, use default config
            self.parallel_config = DiffusionParallelConfig()
        if self.num_gpus is None:
            if self.parallel_config is not None:
                self.num_gpus = self.parallel_config.world_size
            else:
                self.num_gpus = 1
        if self.num_gpus < self.parallel_config.world_size:
            raise ValueError(
                f"num_gpus ({self.num_gpus}) < parallel_config.world_size ({self.parallel_config.world_size})"
            )
        # Convert string dtype to torch.dtype if needed
        if isinstance(self.dtype, str):
            dtype_map = {
                "auto": torch.bfloat16,
                "bfloat16": torch.bfloat16,
                "bf16": torch.bfloat16,
                "float16": torch.float16,
                "fp16": torch.float16,
                "half": torch.float16,
                "float32": torch.float32,
                "fp32": torch.float32,
                "float": torch.float32,
            }
            dtype_lower = self.dtype.lower()
            if dtype_lower in dtype_map:
                self.dtype = dtype_map[dtype_lower]
            else:
                logger.warning(f"Unknown dtype string '{self.dtype}', defaulting to bfloat16")
                self.dtype = torch.bfloat16
        # Convert cache_config dict to DiffusionCacheConfig if needed
        if isinstance(self.cache_config, dict):
            self.cache_config = DiffusionCacheConfig.from_dict(self.cache_config)
        elif not isinstance(self.cache_config, DiffusionCacheConfig):
            # If it's neither dict nor DiffusionCacheConfig, convert to empty config
            self.cache_config = DiffusionCacheConfig()
        if self.max_cpu_loras is None:
            self.max_cpu_loras = 1
        elif self.max_cpu_loras < 1:
            raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA")

    def update_multimodal_support(self) -> None:
        """Enable multimodal-input support for pipelines known to accept it."""
        self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"}

    @classmethod
    def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":
        """Build a config from keyword arguments, ignoring unknown keys.

        Also normalizes the legacy "static_lora_scale" kwarg and falls back
        to environment variables for the cache backend.
        """
        # Backwards-compatibility: older callers may use a diffusion-specific
        # "static_lora_scale" kwarg. Normalize it to the canonical "lora_scale"
        # before constructing the dataclass to avoid TypeError on unknown fields.
        if "static_lora_scale" in kwargs:
            if "lora_scale" not in kwargs:
                kwargs["lora_scale"] = kwargs["static_lora_scale"]
            kwargs.pop("static_lora_scale", None)
        # Check environment variable as fallback for cache_backend
        # Support both old DIFFUSION_CACHE_ADAPTER and new DIFFUSION_CACHE_BACKEND for backwards compatibility
        if "cache_backend" not in kwargs:
            cache_backend = os.environ.get("DIFFUSION_CACHE_BACKEND") or os.environ.get("DIFFUSION_CACHE_ADAPTER")
            kwargs["cache_backend"] = cache_backend.lower() if cache_backend else "none"
        # Filter kwargs to only include valid fields
        valid_fields = {f.name for f in fields(cls)}
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}
        return cls(**filtered_kwargs)
@dataclass
class DiffusionOutput:
    """
    Final output (after pipeline completion)
    """

    # Decoded output tensor; None when the run produced no output.
    output: torch.Tensor | None = None
    # Per-step timesteps recorded along the denoising trajectory, if kept.
    trajectory_timesteps: list[torch.Tensor] | None = None
    # Latents recorded along the denoising trajectory, if kept.
    trajectory_latents: torch.Tensor | None = None
    # Decoded frames for trajectory steps, if kept.
    trajectory_decoded: list[torch.Tensor] | None = None
    # Error description when the request failed; None on success.
    error: str | None = None
    # Optional callable applied to the output downstream (e.g. format conversion).
    post_process_func: Callable[..., Any] | None = None
    # logged timings info, directly from Req.timings
    # timings: Optional["RequestTimings"] = None
class AttentionBackendEnum(enum.Enum):
    """Attention backend implementations selectable for diffusion models."""

    FA = enum.auto()
    SLIDING_TILE_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    SAGE_ATTN = enum.auto()
    SAGE_ATTN_THREE = enum.auto()
    VIDEO_SPARSE_ATTN = enum.auto()
    VMOBA_ATTN = enum.auto()
    AITER = enum.auto()
    NO_ATTENTION = enum.auto()

    def __str__(self) -> str:
        """Render the backend as its lowercase member name."""
        return self.name.lower()
# Special message broadcast via scheduler queues to signal worker shutdown.
# Workers compare incoming messages against this sentinel to know when to exit.
SHUTDOWN_MESSAGE = {"type": "shutdown"}
Prev
1
…
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment