Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
vllm-omni-das
Commits
c1cacde6
Commit
c1cacde6
authored
Mar 25, 2026
by
weishb
Browse files
vllm-omni_0.15.0.rc1+fix1 first commit
parent
35607782
Changes
306
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3210 additions
and
0 deletions
+3210
-0
vllm_omni/diffusion/attention/backends/ring_flash_attn.py
vllm_omni/diffusion/attention/backends/ring_flash_attn.py
+316
-0
vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py
vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py
+128
-0
vllm_omni/diffusion/attention/backends/sage_attn.py
vllm_omni/diffusion/attention/backends/sage_attn.py
+72
-0
vllm_omni/diffusion/attention/backends/sdpa.py
vllm_omni/diffusion/attention/backends/sdpa.py
+122
-0
vllm_omni/diffusion/attention/backends/utils/__init__.py
vllm_omni/diffusion/attention/backends/utils/__init__.py
+13
-0
vllm_omni/diffusion/attention/backends/utils/fa.py
vllm_omni/diffusion/attention/backends/utils/fa.py
+259
-0
vllm_omni/diffusion/attention/layer.py
vllm_omni/diffusion/attention/layer.py
+133
-0
vllm_omni/diffusion/attention/parallel/__init__.py
vllm_omni/diffusion/attention/parallel/__init__.py
+22
-0
vllm_omni/diffusion/attention/parallel/base.py
vllm_omni/diffusion/attention/parallel/base.py
+82
-0
vllm_omni/diffusion/attention/parallel/factory.py
vllm_omni/diffusion/attention/parallel/factory.py
+71
-0
vllm_omni/diffusion/attention/parallel/ring.py
vllm_omni/diffusion/attention/parallel/ring.py
+175
-0
vllm_omni/diffusion/attention/parallel/ulysses.py
vllm_omni/diffusion/attention/parallel/ulysses.py
+238
-0
vllm_omni/diffusion/attention/selector.py
vllm_omni/diffusion/attention/selector.py
+85
-0
vllm_omni/diffusion/cache/__init__.py
vllm_omni/diffusion/cache/__init__.py
+27
-0
vllm_omni/diffusion/cache/base.py
vllm_omni/diffusion/cache/base.py
+112
-0
vllm_omni/diffusion/cache/cache_dit_backend.py
vllm_omni/diffusion/cache/cache_dit_backend.py
+923
-0
vllm_omni/diffusion/cache/selector.py
vllm_omni/diffusion/cache/selector.py
+38
-0
vllm_omni/diffusion/cache/teacache/__init__.py
vllm_omni/diffusion/cache/teacache/__init__.py
+45
-0
vllm_omni/diffusion/cache/teacache/backend.py
vllm_omni/diffusion/cache/teacache/backend.py
+152
-0
vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
+197
-0
No files found.
Too many changes to show.
To preserve performance only
306 of 306+
files are displayed.
Plain diff
Email patch
vllm_omni/diffusion/attention/backends/ring_flash_attn.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
import
torch
from
vllm_omni.diffusion.attention.backends.ring.ring_selector
import
AttnType
,
select_flash_attn_impl
from
vllm_omni.diffusion.attention.backends.ring.ring_utils
import
update_out_and_lse
from
vllm_omni.diffusion.distributed.comm
import
RingComm
def ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
):
    """Forward pass of Ring Attention over a sequence-sharded q/k/v.

    Each rank holds one shard of the sequence; k/v shards are rotated around
    the ring (via ``RingComm.send_recv``) for ``world_size`` steps, and the
    per-step partial attention outputs are merged with ``update_out_and_lse``.

    NOTE(review): ``deterministic`` is accepted but never used in this body.
    NOTE(review): assumes q/k/v have the sequence on dim 1 (joint tensors are
    concatenated with ``dim=1``) — confirm layout against the backend kernels.

    Returns:
        (out, lse): merged attention output cast to ``q.dtype`` and, for
        non-SPARSE_SAGE backends, the log-sum-exp reshaped via
        ``squeeze(-1).transpose(1, 2)``.
    """
    # Validate causal + joint_strategy combination.
    # When causal=True and joint_strategy="rear", the causal mask would incorrectly
    # prevent local query tokens from attending to joint key tokens (which are
    # concatenated at the end). This breaks the semantics where joint tokens
    # (e.g., text conditioning) should be visible to all local tokens.
    if causal and joint_tensor_key is not None and joint_strategy == "rear":
        raise ValueError(
            "joint_strategy='rear' is not compatible with causal=True in Ring Attention. "
            "When using causal attention with joint tokens, use joint_strategy='front' "
            "to ensure joint tokens act as a visible prefix for all local tokens. "
            "With 'rear' strategy, the causal mask would incorrectly block local tokens "
            "from seeing the joint tokens."
        )
    comm = RingComm(process_group)
    out = None
    lse = None
    next_k, next_v = None, None
    # Check and adjust q, k, v to be contiguous (required for P2P communication).
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()
    if not v.is_contiguous():
        v = v.contiguous()
    for step in range(comm.world_size):
        # Kick off the async exchange for the NEXT step before computing this one,
        # overlapping communication with computation. Skipped on the last step.
        if step + 1 != comm.world_size:
            next_k: torch.Tensor
            next_v: torch.Tensor
            next_k = comm.send_recv(k)
            next_v = comm.send_recv(v)
            comm.commit()
        # With causal masking, ranks only attend to k/v blocks at or before
        # their own position in the ring; later blocks are fully masked out.
        if not causal or step <= comm.rank:
            step_k = k
            step_v = v
            # Joint (e.g. conditioning) k/v participate exactly once, at step 0.
            if step == 0 and joint_tensor_key is not None:
                if joint_strategy == "front":
                    step_k = torch.cat([joint_tensor_key, step_k], dim=1)
                    step_v = torch.cat([joint_tensor_value, step_v], dim=1)
                else:
                    step_k = torch.cat([step_k, joint_tensor_key], dim=1)
                    step_v = torch.cat([step_v, joint_tensor_value], dim=1)
            fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor)
            block_out, block_lse = fn(
                q,
                step_k,
                step_v,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                # The full causal mask only applies to the rank's own (diagonal) block.
                causal=causal and step == 0,
                window_size=window_size,
                softcap=softcap,
                alibi_slopes=alibi_slopes,
                return_softmax=True and dropout_p > 0,
            )
            # Ensure block_out is contiguous if needed, though usually it is from FA.
            if attn_type == AttnType.SPARSE_SAGE:
                # NOTE(review): SPARSE_SAGE overwrites rather than accumulates across
                # steps — presumably the sparse kernel handles merging internally;
                # confirm against the SPARSE_SAGE implementation.
                out, lse = block_out, block_lse
            else:
                out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        # Wait for the in-flight exchange and rotate k/v for the next step.
        if step + 1 != comm.world_size:
            comm.wait()
            k = next_k
            v = next_v
    out = out.to(q.dtype)
    if attn_type != AttnType.SPARSE_SAGE:
        lse = lse.squeeze(dim=-1).transpose(1, 2)
    return out, lse
class RingFlashAttnFunc(torch.autograd.Function):
    """Ring Flash Attention autograd function (inference only, no backward)."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
        attn_type,
        attn_processor,
        joint_tensor_key=None,
        joint_tensor_value=None,
        joint_strategy="front",
    ):
        """Run the ring forward pass; ``ctx`` is unused (no backward is defined).

        Returns ``out`` or, when ``return_softmax`` is truthy, the tuple
        ``(out, softmax_lse, None)`` (the trailing ``None`` keeps the
        flash-attn return-softmax tuple shape).
        """
        # Default scale: 1/sqrt(head_dim), the standard attention scaling.
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        # ALiBi is not supported by this ring implementation.
        assert alibi_slopes is None
        # Contiguity is required before the P2P ring exchange.
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out, softmax_lse = ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            # NOTE(review): the caller's `deterministic` flag is ignored here and
            # False is always forwarded — confirm this is intentional.
            deterministic=False,
            attn_type=attn_type,
            attn_processor=attn_processor,
            joint_tensor_key=joint_tensor_key,
            joint_tensor_value=joint_tensor_value,
            joint_strategy=joint_strategy,
        )
        return out if not return_softmax else (out, softmax_lse, None)
def ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    """Ring Attention over a packed tensor with q/k/v stacked along dim 2."""
    # Unpack the stacked projections before dispatching to the autograd function.
    query = qkv[:, :, 0]
    key = qkv[:, :, 1]
    value = qkv[:, :, 2]
    return RingFlashAttnFunc.apply(
        query,
        key,
        value,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        None,  # attn_processor
        None,  # joint_tensor_key
        None,  # joint_tensor_value
        "front",  # joint_strategy
    )
def ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
):
    """Ring Attention with a separate query and k/v stacked along dim 2."""
    # Split the packed key/value projections before dispatching.
    key = kv[:, :, 0]
    value = kv[:, :, 1]
    return RingFlashAttnFunc.apply(
        q,
        key,
        value,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        None,  # attn_processor
        None,  # joint_tensor_key
        None,  # joint_tensor_value
        "front",  # joint_strategy
    )
def ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, None]:
    """Ring Attention forward pass using Flash Attention backend.

    Implements Ring Attention with sequence parallelism using a ring-based P2P
    communication pattern. The sequence dimension is sharded across devices, and
    Key/Value blocks are circulated through the ring to accumulate attention results.

    Args:
        q (torch.Tensor): Query tensor of shape (batch, seq_len, num_heads, head_dim).
            Sequence dimension is sharded across the ring group.
        k (torch.Tensor): Key tensor of shape (batch, seq_len, num_heads, head_dim).
            Sequence dimension is sharded across the ring group.
        v (torch.Tensor): Value tensor of shape (batch, seq_len, num_heads, head_dim).
            Sequence dimension is sharded across the ring group.
        dropout_p (float): Dropout probability. Defaults to 0.0.
        softmax_scale (float | None): Scaling factor for softmax.
            If None, computed as head_dim^(-0.5).
        causal (bool): Whether to apply causal masking. Defaults to False.
        window_size (tuple[int, int]): Sliding window size for attention.
            (-1, -1) means no windowing.
        softcap (float): Soft capping value for attention logits. Defaults to 0.0.
        alibi_slopes (torch.Tensor | None): ALiBi slopes for positional bias.
            Not supported.
        deterministic (bool): Whether to use deterministic algorithms.
            Defaults to False.
        return_attn_probs (bool): If True, returns (out, softmax_lse, None).
            Defaults to False.
        group (ProcessGroup | None): Process group for ring communication.
            Defaults to None.
        attn_type (AttnType): Flash Attention implementation type
            (AttnType.FA, AttnType.FA3, etc.).
        attn_processor (Callable | None): Custom attention processor for sparse
            attention. Defaults to None.
        joint_tensor_key (torch.Tensor | None): Additional key tensor for joint
            attention (e.g., text + image). Concatenated only at step=0.
            Defaults to None.
        joint_tensor_value (torch.Tensor | None): Additional value tensor for
            joint attention (e.g., text + image). Concatenated only at step=0.
            Defaults to None.
        joint_strategy (str): Concatenation strategy ("front" or "rear"; any
            value other than "front" is treated as rear-concatenation, and
            "rear" combined with causal=True raises ValueError).
            Defaults to "front".

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, None]]:
            - If return_attn_probs is False: Output tensor (batch, seq_len, num_heads, head_dim).
            - If return_attn_probs is True: A tuple (out, softmax_lse, None).
    """
    return RingFlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
        joint_tensor_key,
        joint_tensor_value,
        joint_strategy,
    )
vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
# adapted from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py
# Copyright 2024 The HuggingFace Inc. team and Jiarui Fang.
import
torch
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.attention.backends.ring.ring_kernels
import
pytorch_attn_forward
from
vllm_omni.diffusion.attention.backends.ring.ring_utils
import
update_out_and_lse
from
vllm_omni.diffusion.distributed.comm
import
RingComm
logger
=
init_logger
(
__name__
)
def ring_pytorch_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    op_type="efficient",
    joint_tensor_key=None,
    joint_tensor_value=None,
    joint_strategy="front",
):
    """Ring Attention using the pure-PyTorch attention backend.

    Mirrors the flash-attn ring API so callers can switch backends without
    changing call sites.

    NOTE: ``dropout_p``, ``window_size``, ``softcap``, ``alibi_slopes``,
    ``deterministic`` and ``return_attn_probs`` are accepted for signature
    compatibility only — they are NOT forwarded to ``RingAttentionFunc.apply``
    and therefore have no effect in this backend.
    """
    return RingAttentionFunc.apply(
        group,
        q,
        k,
        v,
        softmax_scale,
        causal,
        op_type,
        joint_tensor_key,
        joint_tensor_value,
        joint_strategy,
    )
class RingAttentionFunc(torch.autograd.Function):
    """Ring Attention autograd function using PyTorch SDPA (inference only, no backward)."""

    @staticmethod
    def forward(
        ctx,
        group,
        q,
        k,
        v,
        sm_scale,
        is_causal,
        op_type,
        joint_tensor_key=None,
        joint_tensor_value=None,
        joint_strategy="front",
    ):
        """Rotate k/v shards around the ring and accumulate attention.

        Same communication pattern as the flash-attn variant, but each block is
        computed by ``pytorch_attn_forward`` (``op_type`` selects the kernel).
        Returns only the merged output (no LSE). ``ctx`` is unused.
        """
        # Validate causal + joint_strategy combination.
        # When causal=True and joint_strategy="rear", the causal mask would incorrectly
        # prevent local query tokens from attending to joint key tokens (which are
        # concatenated at the end). This breaks the semantics where joint tokens
        # (e.g., text conditioning) should be visible to all local tokens.
        if is_causal and joint_tensor_key is not None and joint_strategy == "rear":
            raise ValueError(
                "joint_strategy='rear' is not compatible with causal=True in Ring Attention. "
                "When using causal attention with joint tokens, use joint_strategy='front' "
                "to ensure joint tokens act as a visible prefix for all local tokens. "
                "With 'rear' strategy, the causal mask would incorrectly block local tokens "
                "from seeing the joint tokens."
            )
        comm = RingComm(group)
        # Ensure tensors are contiguous for P2P communication.
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out, lse = None, None
        next_k, next_v = None, None
        # Default scale: 1/sqrt(head_dim).
        if sm_scale is None:
            sm_scale = q.shape[-1] ** -0.5
        for step in range(comm.world_size):
            # Start the async k/v exchange for the next step (overlaps with compute).
            if step + 1 != comm.world_size:
                next_k = comm.send_recv(k)
                next_v = comm.send_recv(v)
                comm.commit()
            # With causal masking, skip blocks strictly after this rank's position.
            if not is_causal or step <= comm.rank:
                step_k = k
                step_v = v
                # Joint (conditioning) k/v participate exactly once, at step 0.
                if step == 0 and joint_tensor_key is not None:
                    if joint_strategy == "front":
                        step_k = torch.cat([joint_tensor_key, step_k], dim=1)
                        step_v = torch.cat([joint_tensor_value, step_v], dim=1)
                    else:
                        step_k = torch.cat([step_k, joint_tensor_key], dim=1)
                        step_v = torch.cat([step_v, joint_tensor_value], dim=1)
                block_out, block_lse = pytorch_attn_forward(
                    q,
                    step_k,
                    step_v,
                    softmax_scale=sm_scale,
                    # The full causal mask only applies to the diagonal (step 0) block.
                    causal=is_causal and step == 0,
                    op_type=op_type,
                )
                out, lse = update_out_and_lse(out, lse, block_out, block_lse)
            # Complete the in-flight exchange and rotate k/v for the next step.
            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k
                v = next_v
        out = out.to(q.dtype)
        return out
vllm_omni/diffusion/attention/backends/sage_attn.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
)
logger
=
init_logger
(
__name__
)
try
:
from
sageattention
import
sageattn
except
ImportError
:
logger
.
warning
(
"SageAttentionBackend is not available. You may install sage-attention"
" by pip install git+https://github.com/thu-ml/SageAttention.git"
)
raise
ImportError
# TODO add sage3 attention backend
class SageAttentionBackend(AttentionBackend):
    """Attention backend backed by the SageAttention quantized kernels."""

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "SAGE_ATTN"

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # Multiples of 32, from 32 up to 256.
        return [32 * i for i in range(1, 9)]

    @staticmethod
    def get_impl_cls() -> type["SageAttentionImpl"]:
        return SageAttentionImpl
class SageAttentionImpl(AttentionImpl):
    """Executes attention through the ``sageattn`` kernel (NHD layout)."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        causal: bool = False,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Only the scale and causal flag are consumed; the remaining arguments
        # exist to satisfy the common AttentionImpl constructor signature.
        self.softmax_scale = softmax_scale
        self.causal = causal

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata = None,
    ) -> torch.Tensor:
        # attn_metadata is ignored: sageattn takes no mask argument here.
        return sageattn(
            query,
            key,
            value,
            tensor_layout="NHD",
            is_causal=self.causal,
            sm_scale=self.softmax_scale,
        )
vllm_omni/diffusion/attention/backends/sdpa.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
)
logger
=
init_logger
(
__name__
)
def
_maybe_reshape_attn_mask
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
attn_mask
:
torch
.
Tensor
|
None
=
None
):
"""
Reshape Attention Mask
[batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
"""
# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
if
attn_mask
is
not
None
and
torch
.
all
(
attn_mask
!=
0
):
attn_mask
=
None
# Reshape Attention Mask
# [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
if
(
attn_mask
is
not
None
and
attn_mask
.
ndim
==
2
and
attn_mask
.
shape
[
0
]
==
query
.
shape
[
0
]
and
attn_mask
.
shape
[
1
]
==
key
.
shape
[
1
]
):
B
,
Sq
,
Skv
=
attn_mask
.
shape
[
0
],
query
.
shape
[
1
],
key
.
shape
[
1
]
attn_mask
=
attn_mask
.
to
(
torch
.
bool
)
attn_mask
=
attn_mask
.
unsqueeze
(
1
).
expand
(
B
,
Sq
,
Skv
).
unsqueeze
(
1
).
contiguous
()
return
attn_mask
class SDPABackend(AttentionBackend):
    """Backend that dispatches to ``torch.nn.functional.scaled_dot_product_attention``."""

    accept_output_buffer: bool = True

    @classmethod
    def supports_attention_mask(cls) -> bool:
        # SDPA accepts an explicit attn_mask argument, unlike some fused kernels.
        return True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # SDPA imposes no kernel-level head-size restriction; advertise 0..1023.
        # (idiom fix: list(range(...)) instead of a pass-through comprehension)
        # TODO: confirm the intended upper bound — the original carried a "todo".
        return list(range(1024))

    @staticmethod
    def get_name() -> str:
        return "SDPA"

    @staticmethod
    def get_impl_cls() -> type["SDPAImpl"]:
        return SDPAImpl
class SDPAImpl(AttentionImpl):
    """Attention implementation delegating to PyTorch scaled_dot_product_attention.

    Tensors are expected in (batch, seq, heads, head_dim) layout; they are
    moved to (batch, heads, seq, head_dim) for SDPA and back afterwards.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        causal: bool = False,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Only the scale and causal flag are consumed; other arguments exist
        # to satisfy the common AttentionImpl constructor signature.
        self.softmax_scale = softmax_scale
        self.causal = causal

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
        # (B, S, H, D) -> (B, H, S, D): transpose(1, 2) is equivalent to
        # permute(0, 2, 1, 3) on a 4-D tensor.
        q, k, v = (t.transpose(1, 2) for t in (query, key, value))
        mask = attn_metadata.attn_mask if attn_metadata else None
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=0.0,
            is_causal=self.causal,
            scale=self.softmax_scale,
        )
        # Restore (B, S, H, D) layout for the caller.
        return attn_out.transpose(1, 2)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
        # SDPA is device-agnostic; reuse the CUDA path.
        return self.forward_cuda(query, key, value, attn_metadata)

    def forward_hip(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
        # SDPA is device-agnostic; reuse the CUDA path.
        return self.forward_cuda(query, key, value, attn_metadata)

    def forward_npu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
        # NPU needs the mask pre-expanded to (B, 1, Sq, Sk) before SDPA.
        if attn_metadata:
            reshaped = _maybe_reshape_attn_mask(query, key, attn_metadata.attn_mask)
            attn_metadata.attn_mask = reshaped
        return self.forward_cuda(query, key, value, attn_metadata)
vllm_omni/diffusion/attention/backends/utils/__init__.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utils for attention backends.
"""
from
vllm_omni.diffusion.attention.backends.utils.fa
import
_pad_input
,
_unpad_input
,
_upad_input
__all__
=
[
"_pad_input"
,
"_unpad_input"
,
"_upad_input"
,
]
vllm_omni/diffusion/attention/backends/utils/fa.py
0 → 100644
View file @
c1cacde6
# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py
import
torch
import
torch.nn.functional
as
F
from
vllm_omni.platforms
import
current_omni_platform
# Flash Attention function detection with fallback chain.
# Resolves `flash_attn_func` / `flash_attn_varlen_func` from the best available
# provider at import time; both stay None when no provider can be imported.
flash_attn_func = None
flash_attn_varlen_func = None
if current_omni_platform.is_rocm():
    # ROCm: try Aiter first
    try:
        from vllm._aiter_ops import is_aiter_found_and_supported

        if is_aiter_found_and_supported():
            from aiter import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        pass
else:
    # CUDA: try FA3 -> FA2 fallback chain
    # Try FA3 from fa3-fwd PyPI package
    try:
        from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        pass
    # Fallback: Try FA3 from flash-attention source build
    if flash_attn_func is None:
        try:
            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
        except (ImportError, ModuleNotFoundError):
            pass
    # Fallback: Try FA2 from flash-attn package (try multiple import paths)
    if flash_attn_func is None:
        try:
            from flash_attn import flash_attn_func, flash_attn_varlen_func  # noqa: F401
        except (ImportError, ModuleNotFoundError):
            pass
    if flash_attn_func is None:
        try:
            from flash_attn.flash_attn_interface import (  # noqa: F401
                flash_attn_func,
                flash_attn_varlen_func,
            )
        except (ImportError, ModuleNotFoundError):
            pass
# If no FA backend available, SDPA backend will be selected at the platform level
# flash_attn_func and flash_attn_varlen_func will be None
HAS_FLASH_ATTN = flash_attn_func is not None
def
_index_first_axis
(
tensor
,
indices
):
"""
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
after flattening the first two dimensions of the tensor. This is functionally equivalent to
FA2's `index_first_axis` and replaces the need to import it.
"""
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
# two dimensions to get (total_tokens, ...) before indexing.
reshaped_tensor
=
tensor
.
reshape
(
-
1
,
*
tensor
.
shape
[
2
:])
return
reshaped_tensor
[
indices
]
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    unpad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    # Combined selection mask: allocated-but-unused tokens are kept too.
    if unused_mask is None:
        combined_mask = attention_mask
    else:
        combined_mask = attention_mask + unused_mask
    total_lens = combined_mask.sum(dim=-1, dtype=torch.int32)
    used_lens = attention_mask.sum(dim=-1, dtype=torch.int32)
    token_indices = torch.nonzero(combined_mask.flatten(), as_tuple=False).flatten()
    longest = total_lens.max().item()
    # Prepend a zero so cu_seqlens[i]:cu_seqlens[i+1] spans sequence i.
    cu_seqlens = F.pad(torch.cumsum(total_lens, dim=0, dtype=torch.int32), (1, 0))
    return (
        _index_first_axis(hidden_states, token_indices),
        token_indices,
        cu_seqlens,
        longest,
        used_lens,
    )
def
_pad_input
(
hidden_states
,
indices
,
batch
,
seqlen
):
"""
pad_input function for flash attention variants that do not have them within their pkg themselves, e.g. fa3.
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim
=
hidden_states
.
shape
[
1
:]
output
=
torch
.
zeros
((
batch
*
seqlen
),
*
dim
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
output
[
indices
]
=
hidden_states
return
output
.
view
(
batch
,
seqlen
,
*
dim
)
def
_get_unpad_data
(
attention_mask
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
int
]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
Arguments:
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
cu_seqlens (`torch.Tensor`):
The cumulative sequence lengths, used to index into ragged (unpadded) tensors.
`cu_seqlens` shape is (batch_size + 1,).
max_seqlen_in_batch (`int`):
Maximum sequence length in batch.
"""
seqlens_in_batch
=
attention_mask
.
sum
(
dim
=-
1
,
dtype
=
torch
.
int32
)
indices
=
torch
.
nonzero
(
attention_mask
.
flatten
(),
as_tuple
=
False
).
flatten
()
# NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
# this might cause a graph break
max_seqlen_in_batch
=
seqlens_in_batch
.
max
().
item
()
cu_seqlens
=
F
.
pad
(
torch
.
cumsum
(
seqlens_in_batch
,
dim
=
0
,
dtype
=
torch
.
int32
),
(
1
,
0
))
return
(
indices
,
cu_seqlens
,
max_seqlen_in_batch
,
)
def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    unpad_input_func,
):
    """
    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong
    to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in
    order to avoid the recomputation of the same intermediary tensors for query, key, value tensors.
    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.
        unpad_input_func:
            The function to use for unpadding the input tensors.
    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    if torch.compiler.is_compiling():
        # allow PyTorch compiler to include operations that return scalar values (like .item())
        torch._dynamo.config.capture_scalar_outputs = True
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    # With static caches, the k/v states may be larger than the mask ->
    # we need to slice them to avoid generating garbage
    # It's a bit of an anti-pattern, but otherwise we silently compute wrong attentions scores
    if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
        key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
    key_layer = _index_first_axis(key_layer, indices_k)
    value_layer = _index_first_axis(value_layer, indices_k)
    if query_length == kv_seq_len:
        # Prefill-style case: queries share the key mask/indices exactly.
        query_layer = _index_first_axis(query_layer, indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # Decode-style case: one query token per sequence.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )
        # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(
            query_layer, attention_mask
        )
    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
def
_is_packed_sequence
(
position_ids
,
batch_size
):
"""
Check the position ids whether packed sequences are indicated or not
1. Position ids exist
2. Flattened sequences only are supported
3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e.
we have multiple increasing sequences
"""
if
position_ids
is
None
:
return
False
increasing_position_sequences
=
torch
.
arange
(
position_ids
.
shape
[
1
],
device
=
position_ids
.
device
)
+
position_ids
.
min
()
return
batch_size
==
1
and
(
increasing_position_sequences
-
position_ids
).
abs
().
sum
().
bool
()
vllm_omni/diffusion/attention/layer.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation and Jiarui Fang
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team & Jiarui Fang
# Adapted from
# https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py
import
torch
import
torch.nn
as
nn
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.attention.backends.abstract
import
AttentionMetadata
from
vllm_omni.diffusion.attention.backends.sdpa
import
SDPABackend
from
vllm_omni.diffusion.attention.parallel
import
build_parallel_attention_strategy
from
vllm_omni.diffusion.attention.parallel.ring
import
RingParallelAttention
from
vllm_omni.diffusion.attention.selector
import
get_attn_backend
from
vllm_omni.diffusion.distributed.parallel_state
import
get_sp_group
from
vllm_omni.diffusion.forward_context
import
get_forward_context
logger
=
init_logger
(
__name__
)
class
Attention
(
nn
.
Module
):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        # ulysses attention
        scatter_idx: int = 2,
        gather_idx: int = 1,
        use_sync: bool = False,
    ):
        """Build the attention module: backend impl, SDPA fallback, optional ring setup.

        Ring attention is enabled only when the forward context exposes a
        parallel_config with ring_degree > 1 AND a sequence-parallel group is
        available; any failure along that path silently falls back to local
        attention (use_ring=False).
        """
        super().__init__()
        # Resolve the platform-selected attention backend and instantiate its impl.
        self.attn_backend = get_attn_backend(-1)
        self.attn_impl_cls = self.attn_backend.get_impl_cls()
        self.attention = self.attn_impl_cls(
            num_heads=num_heads,
            head_size=head_size,
            softmax_scale=softmax_scale,
            causal=causal,
            num_kv_heads=num_kv_heads,
        )
        # Instantiate fallback backend for float32 support
        self.sdpa_fallback = SDPABackend.get_impl_cls()(
            num_heads=num_heads,
            head_size=head_size,
            softmax_scale=softmax_scale,
            causal=causal,
            num_kv_heads=num_kv_heads,
        )
        self.backend_pref = None
        self.softmax_scale = softmax_scale
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.use_sync = use_sync
        self.causal = causal
        self.use_ring = False
        self.ring_pg = None
        self.ring_runner = None
        try:
            config = get_forward_context().omni_diffusion_config
            self.backend_pref = config.attention_backend
            if config.parallel_config.ring_degree > 1:
                self.use_ring = True
                try:
                    sp_group = get_sp_group()
                    self.ring_pg = sp_group.ring_group
                    self.ring_runner = RingParallelAttention(sp_group)
                except Exception:
                    # SP group unavailable: degrade to non-ring attention.
                    # NOTE(review): broad Exception here hides real setup errors —
                    # consider narrowing or logging.
                    self.use_ring = False
                    self.ring_runner = None
        except Exception:
            # No forward context / config at construction time: non-ring default.
            # NOTE(review): also a broad catch; confirm this is intentional.
            self.use_ring = False
            self.ring_runner = None
        # Strategy object that handles pre/post communication (e.g. Ulysses AllToAll).
        self.parallel_strategy = build_parallel_attention_strategy(
            scatter_idx=scatter_idx,
            gather_idx=gather_idx,
            use_sync=use_sync,
        )
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
=
None
,
)
->
torch
.
Tensor
:
# 1. Prepare inputs (Communication / Resharding)
# For Ulysses: AllToAll Q/K/V; Slicing joint_q/k/v
# For Ring: Concat joint_q
query
,
key
,
value
,
attn_metadata
,
ctx
=
self
.
parallel_strategy
.
pre_attention
(
query
,
key
,
value
,
attn_metadata
)
# 2. Kernel Execution (Computation)
if
self
.
use_ring
:
out
=
self
.
_run_ring_attention
(
query
,
key
,
value
,
attn_metadata
)
else
:
out
=
self
.
_run_local_attention
(
query
,
key
,
value
,
attn_metadata
)
# 3. Post-processing (Reverse Communication)
# For Ulysses: AllToAll Output, and AllGather Joint Output
out
=
self
.
parallel_strategy
.
post_attention
(
out
,
ctx
)
return
out
def
_run_local_attention
(
self
,
query
,
key
,
value
,
attn_metadata
):
if
query
.
dtype
==
torch
.
float32
:
logger
.
warning_once
(
f
"Only SDPA supports float32. Overriding user config
{
type
(
self
.
attention
)
}
"
f
"attention_backend='
{
self
.
backend_pref
}
' to 'sdpa' for dtype=
{
query
.
dtype
}
."
)
return
self
.
sdpa_fallback
.
forward
(
query
,
key
,
value
,
attn_metadata
)
# Fallback to standard attention
return
self
.
attention
.
forward
(
query
,
key
,
value
,
attn_metadata
)
def
_run_ring_attention
(
self
,
query
,
key
,
value
,
attn_metadata
):
# Delegate to RingParallelAttention strategy if available
if
self
.
ring_runner
is
not
None
:
return
self
.
ring_runner
.
run_attention
(
query
,
key
,
value
,
attn_metadata
,
softmax_scale
=
self
.
softmax_scale
,
causal
=
self
.
causal
)
raise
RuntimeError
(
"Ring attention is enabled but strategy is not RingParallelAttention"
)
vllm_omni/diffusion/attention/parallel/__init__.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Parallel attention strategies.
This package provides **communication / resharding strategies** for attention,
orthogonal to the **attention kernel backend** (SDPA/Flash/Sage).
The goal is to keep `vllm_omni.diffusion.attention.layer.Attention` small and
extensible: adding a new parallelism method should not require editing the core
Attention module, only adding a new strategy and selecting it in the factory.
"""
from
.base
import
NoParallelAttention
,
ParallelAttentionContext
,
ParallelAttentionStrategy
from
.factory
import
build_parallel_attention_strategy
# Public API re-exported by the parallel attention strategy package.
__all__ = [
    "ParallelAttentionStrategy",
    "ParallelAttentionContext",
    "NoParallelAttention",
    "build_parallel_attention_strategy",
]
vllm_omni/diffusion/attention/parallel/base.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
typing
import
Protocol
import
torch
from
vllm_omni.diffusion.attention.backends.abstract
import
AttentionMetadata
@dataclass(frozen=True, slots=True)
class ParallelAttentionContext:
    """Opaque per-forward context returned by a parallel strategy.

    Strategies may stash whatever they need here to finish post-processing after
    the attention kernel runs (e.g. reverse resharding, slicing metadata, etc.).
    """

    # Name of the strategy that produced this context (e.g. "none", "ulysses", "ring").
    name: str
class ParallelAttentionStrategy(Protocol):
    """Pluggable strategy for parallel attention communication/resharding.

    This is intentionally orthogonal to the attention *kernel* backend.
    The kernel backend implements `AttentionImpl.forward()` for a given device,
    while the parallel strategy implements how Q/K/V and outputs are sharded /
    communicated across ranks.
    """

    @property
    def enabled(self) -> bool:
        """Whether this strategy performs any cross-rank communication."""
        ...

    @property
    def name(self) -> str:
        """Short strategy identifier (e.g. "none", "ulysses", "ring")."""
        ...

    def pre_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        AttentionMetadata | None,
        ParallelAttentionContext | None,
    ]:
        """Runs before the attention kernel.

        Returns possibly transformed Q/K/V and metadata, and an optional context
        for `post_attention`.
        """

    def post_attention(
        self,
        attn_output: torch.Tensor,
        ctx: ParallelAttentionContext | None,
    ) -> torch.Tensor:
        """Runs after the attention kernel (e.g. reverse communication)."""
class NoParallelAttention:
    """Fallback strategy for single-device / non-SP runs: every hook is a no-op."""

    @property
    def enabled(self) -> bool:
        """Sequence parallelism is never active for this strategy."""
        return False

    @property
    def name(self) -> str:
        """Identifier used in contexts and diagnostics."""
        return "none"

    def pre_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None,
    ):
        """Pass Q/K/V and metadata through untouched; no context is required."""
        return (query, key, value, attn_metadata, None)

    def post_attention(
        self,
        attn_output: torch.Tensor,
        ctx: ParallelAttentionContext | None,
    ) -> torch.Tensor:
        """Return the kernel output unchanged."""
        return attn_output
vllm_omni/diffusion/attention/parallel/factory.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.attention.parallel.base
import
NoParallelAttention
,
ParallelAttentionStrategy
from
vllm_omni.diffusion.attention.parallel.ring
import
RingParallelAttention
from
vllm_omni.diffusion.attention.parallel.ulysses
import
UlyssesParallelAttention
from
vllm_omni.diffusion.distributed.parallel_state
import
get_sequence_parallel_world_size
,
get_sp_group
from
vllm_omni.diffusion.forward_context
import
get_forward_context
logger
=
init_logger
(
__name__
)
def build_parallel_attention_strategy(
    *,
    scatter_idx: int,
    gather_idx: int,
    use_sync: bool,
) -> ParallelAttentionStrategy:
    """Pick the parallel attention strategy matching the active diffusion config.

    Design principle:
    - Attention kernel backend selection remains in `attention/selector.py`.
    - Parallel attention selection happens here, from the distributed config
      and the initialized process groups.
    Falls back to :class:`NoParallelAttention` whenever the forward context or
    the SP group is unavailable, or sequence parallelism is effectively off.
    """
    # Without a forward context there is no diffusion config to inspect.
    try:
        parallel_cfg = get_forward_context().omni_diffusion_config.parallel_config
    except Exception as err:
        logger.debug(f"No forward context available for parallel attention strategy: {err}")
        return NoParallelAttention()

    ulysses_degree = getattr(parallel_cfg, "ulysses_degree", 1)
    ring_degree = getattr(parallel_cfg, "ring_degree", 1)

    try:
        sp_group = get_sp_group()
        if get_sequence_parallel_world_size() <= 1:
            # SP group exists but spans a single rank: nothing to parallelize.
            return NoParallelAttention()
    except Exception as err:
        if ulysses_degree > 1 or ring_degree > 1:
            # SP was requested in the config, yet its group is missing — warn loudly.
            logger.warning(
                f"SP configured (ulysses={ulysses_degree}, ring={ring_degree}) but SP group not available: {err}. "
                f"Falling back to NoParallelAttention. This may cause incorrect results."
            )
        return NoParallelAttention()

    if ulysses_degree > 1:
        # Covers both pure Ulysses and hybrid Ulysses+Ring setups.
        logger.debug(f"Using UlyssesParallelAttention (ulysses_degree={ulysses_degree})")
        return UlyssesParallelAttention(
            sp_group=sp_group,
            scatter_idx=scatter_idx,
            gather_idx=gather_idx,
            use_sync=use_sync,
        )

    if ring_degree > 1:
        logger.debug(f"Using RingParallelAttention (ring_degree={ring_degree})")
        return RingParallelAttention(
            sp_group=sp_group,
        )

    return NoParallelAttention()
vllm_omni/diffusion/attention/parallel/ring.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.logger
import
init_logger
# import torch.distributed as dist # Not used directly here, but good practice if needed
from
vllm_omni.diffusion.attention.backends.ring.ring_globals
import
HAS_FA3
,
HAS_FLASH_ATTN
from
vllm_omni.diffusion.attention.backends.ring.ring_selector
import
AttnType
from
vllm_omni.diffusion.attention.parallel.base
import
(
ParallelAttentionContext
,
# ParallelAttentionStrategy, # Not used in type hint below currently
)
from
vllm_omni.diffusion.distributed.group_coordinator
import
SequenceParallelGroupCoordinator
# from vllm_omni.diffusion.attention.backends.ring_selector import AttnType # Already imported above
from
vllm_omni.diffusion.forward_context
import
get_forward_context
if
TYPE_CHECKING
:
from
vllm_omni.diffusion.attention.backends.abstract
import
AttentionMetadata
@dataclass(frozen=True, slots=True)
class _RingCtx(ParallelAttentionContext):
    """Per-forward context for Ring sequence-parallel attention."""

    # Ring attention typically doesn't need complex context for post-processing
    # as the output is already correctly sharded along sequence dimension, so
    # only the inherited `name` field is carried.
    pass
class RingParallelAttention:
    """Ring sequence-parallel strategy.

    This strategy prepares inputs for Ring Attention.
    Key responsibilities:
    - Concatenate joint_query (Text) to query (Image) if present.
    - Keep joint_key/value separate in metadata for the Ring kernel to handle as static prefix.
    """

    def __init__(
        self,
        sp_group: SequenceParallelGroupCoordinator,
        attn_backend_pref: str | None = None,
    ) -> None:
        """Store the SP group and an optional kernel-backend preference.

        Args:
            sp_group: Coordinator providing the ring process group.
            attn_backend_pref: Preferred ring kernel ("sdpa"/"torch" or a
                flash variant); when None it is resolved from the forward
                context at call time.
        """
        self._sp_group = sp_group
        self.attn_backend_pref = attn_backend_pref

    @property
    def enabled(self) -> bool:
        """This strategy always performs cross-rank communication."""
        return True

    @property
    def name(self) -> str:
        """Strategy identifier used in contexts/logs."""
        return "ring"

    def pre_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None,
    ):
        """Concatenate the joint (text) query onto the query, if present.

        Joint K/V are intentionally left inside ``attn_metadata``; the ring
        kernel consumes them directly as a static prefix.
        """
        joint_tensor_query = None
        joint_strategy = "front"
        if attn_metadata is not None:
            joint_tensor_query = attn_metadata.joint_query
            joint_strategy = attn_metadata.joint_strategy
        if joint_tensor_query is not None:
            supported_joint_strategy = ["front", "rear"]
            if joint_strategy not in supported_joint_strategy:
                raise ValueError(f"joint_strategy: {joint_strategy} not supported.")
            # Concatenation is along dim=1 — assumes (batch, seq, ...) layout;
            # confirm against the attention metadata producers.
            if joint_strategy == "front":
                query = torch.cat([joint_tensor_query, query], dim=1)
            else:
                query = torch.cat([query, joint_tensor_query], dim=1)
        # Note: We do NOT concatenate joint_key/value here.
        # They are preserved in attn_metadata and will be passed
        # explicitly to ring_flash_attn_func.
        ctx = _RingCtx(name=self.name)
        return query, key, value, attn_metadata, ctx

    def post_attention(
        self,
        attn_output: torch.Tensor,
        ctx: ParallelAttentionContext | None,
    ) -> torch.Tensor:
        """No-op post-processing for ring attention."""
        # Ring attention output is already sharded correctly along sequence dimension.
        return attn_output

    def run_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None,
        softmax_scale: float | None = None,
        causal: bool = False,
    ) -> torch.Tensor:
        """Run the actual Ring Attention kernel.

        Backend resolution order: explicit ``attn_backend_pref`` -> forward
        context config -> capability fallback (float32 input or missing flash
        attention forces the PyTorch/SDPA ring implementation).
        """
        if softmax_scale is None:
            # Default scaling: 1/sqrt(head_dim).
            softmax_scale = query.shape[-1] ** -0.5
        backend_pref = self.attn_backend_pref
        if backend_pref is None:
            try:
                config = get_forward_context().omni_diffusion_config
                # config might not have attention_backend attribute if not updated
                backend_pref = getattr(config, "attention_backend", None)
            except Exception:
                backend_pref = None
        # Determine attention type with fallback chain: FA3 -> FA2 -> SDPA
        # FP32 is not supported by Flash Attention, force SDPA
        if query.dtype == torch.float32:
            backend_pref = "sdpa"
        elif not HAS_FA3 and not HAS_FLASH_ATTN:
            if backend_pref != "sdpa":
                logger = init_logger(__name__)
                logger.warning_once("Flash Attention (FA2/FA3) is not available! Force enabling SDPA.")
            backend_pref = "sdpa"
        # Extract joint tensors
        joint_key, joint_value = None, None
        joint_strategy = "front"
        if attn_metadata is not None:
            joint_key = attn_metadata.joint_key
            joint_value = attn_metadata.joint_value
            if attn_metadata.joint_strategy is not None:
                joint_strategy = attn_metadata.joint_strategy
        if backend_pref == "sdpa" or backend_pref == "torch":
            # Lazy import so the flash path never pays for the torch backend.
            from vllm_omni.diffusion.attention.backends.ring_pytorch_attn import ring_pytorch_attn_func

            return ring_pytorch_attn_func(
                query,
                key,
                value,
                softmax_scale=softmax_scale,
                causal=causal,
                group=self._sp_group.ring_group,
                op_type="efficient",
                joint_tensor_key=joint_key,
                joint_tensor_value=joint_value,
                joint_strategy=joint_strategy,
            )
        from vllm_omni.diffusion.attention.backends.ring_flash_attn import ring_flash_attn_func

        # Prefer FA3 over FA2 for better performance (FA3 supports Ampere/Ada/Hopper)
        attn_type = AttnType.FA3 if HAS_FA3 else AttnType.FA
        return ring_flash_attn_func(
            query,
            key,
            value,
            dropout_p=0.0,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            deterministic=False,
            group=self._sp_group.ring_group,
            attn_type=attn_type,
            joint_tensor_key=joint_key,
            joint_tensor_value=joint_value,
            joint_strategy=joint_strategy,
        )
vllm_omni/diffusion/attention/parallel/ulysses.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
import
torch
import
torch.distributed
as
dist
from
vllm_omni.diffusion.attention.backends.abstract
import
AttentionMetadata
from
vllm_omni.diffusion.attention.parallel.base
import
ParallelAttentionContext
from
vllm_omni.diffusion.distributed.comm
import
SeqAllToAll4D
from
vllm_omni.diffusion.distributed.group_coordinator
import
SequenceParallelGroupCoordinator
@dataclass(frozen=True, slots=True)
class _UlyssesCtx(ParallelAttentionContext):
    """Per-forward context for Ulysses sequence-parallel attention."""

    # Process group used for the reverse all-to-all / all-gather in post_attention.
    ulysses_pg: dist.ProcessGroup
    # Dims scattered/gathered by SeqAllToAll4D (swapped for the reverse pass).
    scatter_idx: int
    gather_idx: int
    use_sync: bool
    # Length of the joint (text) segment concatenated to the query; 0 when absent.
    joint_len: int = 0
    # Position of the joint segment relative to the main sequence: "front" or "rear".
    joint_strategy: str = "front"
class UlyssesParallelAttention:
    """Ulysses sequence-parallel strategy (all-to-all over seq/head dims).

    This preserves the semantics previously implemented in
    `Attention._forward_ulysses`:
    - If `AttentionMetadata.joint_*` is provided, joint_query/key/value are
      concatenated *after* all-to-all.
    - joint_key/value are assumed to be replicated across SP ranks and are sliced
      by ulysses head rank before concatenation.
    """

    def __init__(
        self,
        sp_group: SequenceParallelGroupCoordinator,
        scatter_idx: int,
        gather_idx: int,
        use_sync: bool,
    ) -> None:
        """Capture the SP group and the forward all-to-all dims.

        Args:
            sp_group: Coordinator exposing the ulysses process group/ranks.
            scatter_idx: Dim scattered in the forward all-to-all (heads).
            gather_idx: Dim gathered in the forward all-to-all (sequence).
            use_sync: Whether SeqAllToAll4D synchronizes eagerly.
        """
        self._sp_group = sp_group
        self._ulysses_pg = sp_group.ulysses_group
        self._scatter_idx = scatter_idx
        self._gather_idx = gather_idx
        self._use_sync = use_sync

    @property
    def enabled(self) -> bool:
        """This strategy always performs cross-rank communication."""
        return True

    @property
    def name(self) -> str:
        """Strategy identifier used in contexts/logs."""
        return "ulysses"

    def pre_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None,
    ):
        """Reshard Q/K/V via all-to-all and splice in joint (text) tensors.

        Steps:
        1. Slice joint Q/K/V heads for this ulysses rank (joint tensors are
           assumed replicated across SP ranks).
        2. All-to-all Q/K/V: (bs, seq/P, heads, d) -> (bs, seq, heads/P, d).
        3. Concatenate the sliced joint query (and, in pure-Ulysses mode, the
           joint K/V) along the sequence dim.
        4. Rebuild ``attn_metadata.attn_mask`` so it covers the combined
           sequence length.

        Raises:
            ValueError: If joint tensors are partially provided or the joint
                strategy is not "front"/"rear".
        """
        joint_tensor_query = joint_tensor_key = joint_tensor_value = None
        joint_strategy = "front"
        joint_len = 0
        if attn_metadata is not None:
            joint_tensor_query = attn_metadata.joint_query
            joint_tensor_key = attn_metadata.joint_key
            joint_tensor_value = attn_metadata.joint_value
            joint_strategy = attn_metadata.joint_strategy
        is_joint = False
        if joint_tensor_query is not None and joint_tensor_key is not None and joint_tensor_value is not None:
            supported_joint_strategy = ["front", "rear"]
            if joint_strategy not in supported_joint_strategy:
                raise ValueError(
                    f"joint_strategy: {joint_strategy} not supported."
                    f" supported joint strategy: {supported_joint_strategy}"
                )
            # Slice joint_query for this Ulysses rank
            # joint_query is (B, S, H, D). We split H (dim 2).
            ulysses_world_size = self._sp_group.ulysses_world_size
            ulysses_rank = self._sp_group.ulysses_rank
            attn_heads_per_ulysses_rank = joint_tensor_query.shape[-2] // ulysses_world_size
            # Note: We use the same heads for Q/K/V
            joint_tensor_query = joint_tensor_query[
                ...,
                attn_heads_per_ulysses_rank * ulysses_rank : attn_heads_per_ulysses_rank * (ulysses_rank + 1),
                :,
            ]
            joint_len = joint_tensor_query.shape[1]
            is_joint = True
        elif joint_tensor_query is None and joint_tensor_key is None and joint_tensor_value is None:
            pass
        else:
            raise ValueError("joint_query, joint_key, and joint_value should be None or not None simultaneously.")
        if is_joint:
            # Slice joint key/value heads for this ulysses rank.
            # Using same slicing logic as query
            attn_heads_per_ulysses_rank_kv = joint_tensor_key.shape[-2] // ulysses_world_size
            joint_tensor_key = joint_tensor_key[
                ...,
                attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1),
                :,
            ]
            joint_tensor_value = joint_tensor_value[
                ...,
                attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1),
                :,
            ]
            # Update metadata with sliced tensors so Ring attention can use them if needed
            if attn_metadata is not None:
                attn_metadata.joint_key = joint_tensor_key
                attn_metadata.joint_value = joint_tensor_value
        # (bs, seq_len/P, head_cnt, head_size) -> (bs, seq_len, head_cnt/P, head_size)
        query = SeqAllToAll4D.apply(self._ulysses_pg, query, self._scatter_idx, self._gather_idx, self._use_sync)
        key = SeqAllToAll4D.apply(self._ulysses_pg, key, self._scatter_idx, self._gather_idx, self._use_sync)
        value = SeqAllToAll4D.apply(self._ulysses_pg, value, self._scatter_idx, self._gather_idx, self._use_sync)
        if is_joint:
            # Concatenate joint query AFTER AllToAll
            # Image query is now (B, S, H/P, D). Joint query is (B, S_txt, H/P, D).
            # This is dimensionally consistent.
            if joint_strategy == "rear":
                query = torch.cat([query, joint_tensor_query], dim=1)
            else:
                query = torch.cat([joint_tensor_query, query], dim=1)
        # Check if Ring Attention is also active (Hybrid mode)
        # If Ring is active, we should NOT concatenate joint_key/value to k/v here.
        # Instead, they should remain in attn_metadata and be passed to the Ring kernel.
        use_ring = self._sp_group.ring_world_size > 1
        if is_joint and not use_ring:
            # Concatenate joint key/value after all-to-all ONLY for pure Ulysses (Local Attention).
            if joint_strategy == "front":
                key = torch.cat([joint_tensor_key, key], dim=1)
                value = torch.cat([joint_tensor_value, value], dim=1)
            else:
                # "rear"
                key = torch.cat([key, joint_tensor_key], dim=1)
                value = torch.cat([value, joint_tensor_value], dim=1)
        ctx = _UlyssesCtx(
            name=self.name,
            ulysses_pg=self._ulysses_pg,
            scatter_idx=self._scatter_idx,
            gather_idx=self._gather_idx,
            use_sync=self._use_sync,
            joint_len=joint_len,
            joint_strategy=joint_strategy,
        )
        if attn_metadata is not None:
            if is_joint:
                if attn_metadata.joint_attn_mask is None and attn_metadata.attn_mask is None:
                    # No masking requested anywhere; keep it absent.
                    attn_metadata.attn_mask = None
                else:
                    # Fill whichever mask half is missing with all-ones so the
                    # two halves can be concatenated to the full query length.
                    if attn_metadata.attn_mask is None:
                        attn_metadata.attn_mask = torch.ones(
                            [query.shape[0], query.shape[1] - attn_metadata.joint_attn_mask.shape[1]],
                            dtype=torch.bool,
                            device=query.device,
                        )
                    elif attn_metadata.joint_attn_mask is None:
                        attn_metadata.joint_attn_mask = torch.ones(
                            [query.shape[0], query.shape[1] - attn_metadata.attn_mask.shape[1]],
                            dtype=torch.bool,
                            device=query.device,
                        )
                    attn_metadata.attn_mask = (
                        torch.cat([attn_metadata.joint_attn_mask, attn_metadata.attn_mask], dim=1)
                        if joint_strategy == "front"
                        else torch.cat([attn_metadata.attn_mask, attn_metadata.joint_attn_mask], dim=1)
                    )
            if attn_metadata.attn_mask is not None:
                # the final attn_mask is ready, the length should be aligned with query length
                assert attn_metadata.attn_mask.shape[1] == query.shape[1], (
                    f"attn_mask length: {attn_metadata.attn_mask.shape[1]} != query length: {query.shape[1]}"
                )
                attn_metadata.attn_mask = attn_metadata.attn_mask.bool().contiguous()
        return query, key, value, attn_metadata, ctx

    def post_attention(
        self,
        attn_output: torch.Tensor,
        ctx: ParallelAttentionContext | None,
    ) -> torch.Tensor:
        """Reverse the forward resharding on the attention output."""
        assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}"
        # If we have joint tensors (Text), they were Head-Sliced.
        # The main sequence (Image) was Sequence-Sliced.
        # attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front').
        if ctx.joint_len > 0:
            joint_len = ctx.joint_len
            if ctx.joint_strategy == "front":
                output_joint = attn_output[:, :joint_len]
                output_img = attn_output[:, joint_len:]
            else:
                output_img = attn_output[:, :-joint_len]
                output_joint = attn_output[:, -joint_len:]
            # 1. Process Image part: Standard Ulysses Reverse (AllToAll)
            # (bs, seq_len, head_cnt/P, head_size) -> (bs, seq_len/P, head_cnt, head_size)
            # SeqAllToAll4D handles: Scatter gather_idx, Gather scatter_idx.
            # Forward: Scatter 2 (H), Gather 1 (S).
            # Reverse: Scatter 1 (S), Gather 2 (H).
            output_img = SeqAllToAll4D.apply(ctx.ulysses_pg, output_img, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync)
            # 2. Process Joint part: AllGather on Heads
            # Input: (B, JointLen, H/P, D). Output: (B, JointLen, H, D).
            # AllGather along dim 2.
            # Ensure tensor is contiguous for all_gather (slicing may create non-contiguous views)
            output_joint = output_joint.contiguous()
            gathered_joint = [torch.zeros_like(output_joint) for _ in range(dist.get_world_size(ctx.ulysses_pg))]
            dist.all_gather(gathered_joint, output_joint, group=ctx.ulysses_pg)
            output_joint = torch.cat(gathered_joint, dim=2)
            # 3. Recombine
            if ctx.joint_strategy == "front":
                return torch.cat([output_joint, output_img], dim=1)
            else:
                return torch.cat([output_img, output_joint], dim=1)
        # Standard Ulysses Reverse
        return SeqAllToAll4D.apply(ctx.ulysses_pg, attn_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync)
vllm_omni/diffusion/attention/selector.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Diffusion attention backend selector.
This module provides the interface for selecting diffusion attention backends.
The actual backend selection logic is delegated to the platform layer
(vllm_omni.platforms), similar to how vLLM handles attention backend selection.
Usage:
from vllm_omni.diffusion.attention.selector import get_attn_backend
# Get the appropriate backend for current platform
backend_cls = get_attn_backend(head_size=64)
# Or override via environment variable
# export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN
"""
import
importlib
import
os
from
functools
import
cache
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.attention.backends.abstract
import
(
AttentionBackend
,
)
logger
=
init_logger
(
__name__
)
def
_load_backend_cls
(
cls_path
:
str
)
->
type
[
AttentionBackend
]:
"""Load a backend class from its fully qualified path.
Args:
cls_path: Fully qualified class path (e.g.,
"vllm_omni.diffusion.attention.backends.sdpa.SDPABackend")
Returns:
The loaded backend class
"""
module_path
,
class_name
=
cls_path
.
rsplit
(
"."
,
1
)
try
:
module
=
importlib
.
import_module
(
module_path
)
backend_class
=
getattr
(
module
,
class_name
)
return
backend_class
except
ImportError
as
e
:
raise
ImportError
(
f
"Failed to import module
{
module_path
}
:
{
e
}
"
)
except
AttributeError
as
e
:
raise
AttributeError
(
f
"Class
{
class_name
}
not found in module:
{
e
}
"
)
@cache
def get_attn_backend(head_size: int) -> type[AttentionBackend]:
    """
    Get attention backend for diffusion models.

    Selection is delegated to the current platform
    (vllm_omni.platforms.current_omni_platform), which picks a backend from:
    1. A user override via the DIFFUSION_ATTENTION_BACKEND environment variable
    2. Platform-specific defaults and capabilities

    This mirrors vLLM's get_attn_backend_cls, where the platform layer decides
    based on hardware capabilities. Results are memoized per head_size.

    Args:
        head_size: Head size for attention computation (may affect backend selection)

    Returns:
        The selected attention backend class
    """
    from vllm_omni.platforms import current_omni_platform

    # A user-specified backend (if any) takes precedence over platform defaults.
    override = os.environ.get("DIFFUSION_ATTENTION_BACKEND")
    cls_path = current_omni_platform.get_diffusion_attn_backend_cls(
        selected_backend=override,
        head_size=head_size,
    )
    return _load_backend_cls(cls_path)
vllm_omni/diffusion/cache/__init__.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Cache module for diffusion model inference acceleration.
This module provides a unified cache backend system for different caching strategies:
- TeaCache: Timestep Embedding Aware Cache for adaptive transformer caching
- cache-dit: DBCache, SCM, and TaylorSeer caching strategies
Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig.
"""
from
vllm_omni.diffusion.cache.base
import
CacheBackend
from
vllm_omni.diffusion.cache.teacache
import
(
CacheContext
,
TeaCacheConfig
,
apply_teacache_hook
,
)
from
vllm_omni.diffusion.cache.teacache.backend
import
TeaCacheBackend
# Public API of the diffusion cache package.
__all__ = [
    "CacheBackend",
    "TeaCacheConfig",
    "CacheContext",
    "TeaCacheBackend",
    "apply_teacache_hook",
]
vllm_omni/diffusion/cache/base.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Base cache backend interface for diffusion models.
This module defines the abstract base class that all cache backends must implement.
Cache backends provide a unified interface for applying different caching strategies
to transformer models.
Main cache backend implementations:
1. CacheDiTBackend: Implements cache-dit acceleration (DBCache, SCM, TaylorSeer) using
the cache-dit library. Inherits from CacheBackend. Used via cache_backend="cache_dit".
2. TeaCacheBackend: Hook-based backend for TeaCache acceleration. Inherits from
CacheBackend. Used via cache_backend="tea_cache".
All backends implement the same interface:
- enable(pipeline): Enable cache on the pipeline
- refresh(pipeline, num_inference_steps, verbose): Refresh cache state
- is_enabled(): Check if cache is enabled
"""
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
import
torch.nn
as
nn
from
vllm_omni.diffusion.data
import
DiffusionCacheConfig
class CacheBackend(ABC):
    """Common interface for diffusion cache backends.

    Concrete backends (e.g. CacheDiTBackend for DBCache/SCM/TaylorSeer via the
    cache-dit library, TeaCacheBackend via hooks) subclass this and implement
    :meth:`enable` and :meth:`refresh` to manage the cache lifecycle. They
    differ in mechanism but share this unified interface.

    Attributes:
        config: DiffusionCacheConfig with backend-specific parameters.
        enabled: True once the cache has been activated via enable().
    """

    def __init__(self, config: DiffusionCacheConfig):
        """Store the configuration; the cache starts out disabled.

        Args:
            config: DiffusionCacheConfig instance with cache-specific parameters
        """
        self.config = config
        self.enabled = False

    @abstractmethod
    def enable(self, pipeline: Any) -> None:
        """Apply the caching strategy to the pipeline's transformer(s).

        Called once during pipeline initialization. Implementations can reach
        the transformer via ``pipeline.transformer`` and the model type via
        ``pipeline.__class__.__name__``.

        Args:
            pipeline: Diffusion pipeline instance to accelerate.
        """
        raise NotImplementedError("Subclasses must implement enable()")

    @abstractmethod
    def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Reset cached values and counters before a new generation.

        Called at the start of each generation to guarantee a clean state.

        Args:
            pipeline: Diffusion pipeline instance (transformer via
                ``pipeline.transformer``).
            num_inference_steps: Step count of the upcoming generation; may be
                used for cache context updates.
            verbose: Whether to log refresh operations (default: True).
        """
        raise NotImplementedError("Subclasses must implement refresh()")

    def is_enabled(self) -> bool:
        """Return True if the cache has been enabled on this backend."""
        return self.enabled

    def __repr__(self) -> str:
        return f"{type(self).__name__}(config={self.config})"
class CachedTransformer(nn.Module):
    """Mixin base for transformers that participate in caching.

    Subclasses can opt out of separate-CFG handling via the class keyword
    ``enable_separate_cfg`` (stored as a class attribute on each subclass).
    """

    def __init__(self, **kwargs):
        # NOTE(review): extra kwargs are accepted but silently dropped (not
        # forwarded to nn.Module) — presumably for subclass signature
        # compatibility; confirm this is intentional.
        super().__init__()
        # Per-instance flag; subclasses/pipelines may flip this at runtime.
        self.do_true_cfg = False

    def __init_subclass__(cls, enable_separate_cfg: bool = True, **kwargs):
        # Record the class-keyword on the subclass itself so it is visible
        # without instantiation.
        cls.enable_separate_cfg = enable_separate_cfg
        super().__init_subclass__(**kwargs)
vllm_omni/diffusion/cache/cache_dit_backend.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
cache-dit integration backend for vllm-omni.
This module provides a CacheDiTBackend class to enable cache-dit acceleration on diffusion
pipelines in vllm-omni, supporting both single and dual-transformer architectures.
"""
import
functools
from
collections.abc
import
Callable
from
contextlib
import
ExitStack
from
typing
import
Any
,
Optional
import
cache_dit
import
torch
from
cache_dit
import
BlockAdapter
,
DBCacheConfig
,
ForwardPattern
,
ParamsModifier
,
TaylorSeerCalibratorConfig
from
cache_dit.caching.block_adapters
import
FakeDiffusionPipeline
from
cache_dit.caching.cache_adapters.cache_adapter
import
CachedAdapter
from
cache_dit.caching.cache_blocks.pattern_0_1_2
import
CachedBlocks_Pattern_0_1_2
from
cache_dit.caching.cache_contexts
import
BasicCacheConfig
from
cache_dit.caching.cache_contexts.cache_manager
import
CachedContextManager
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.cache.base
import
CacheBackend
from
vllm_omni.diffusion.data
import
DiffusionCacheConfig
,
OmniDiffusionConfig
logger
=
init_logger
(
__name__
)
# Small helper to centralize cache-dit summaries.
def cache_summary(pipeline: Any, details: bool = True) -> None:
    """Print cache-dit summaries for every transformer on the pipeline.

    Always summarizes ``pipeline.transformer``; also summarizes
    ``pipeline.transformer_2`` when the pipeline has one (dual-transformer
    architectures such as Wan2.2).

    Args:
        pipeline: Diffusion pipeline whose transformer(s) were cache-enabled.
        details: Forwarded to ``cache_dit.summary`` to control verbosity.
    """
    targets = [pipeline.transformer]
    if hasattr(pipeline, "transformer_2"):
        targets.append(pipeline.transformer_2)
    for tf in targets:
        cache_dit.summary(tf, details=details)
# Registry of custom cache-dit enablers for specific models
# Maps pipeline names to their cache-dit enablement functions
# Models in this registry require custom handling (e.g., dual-transformer architectures)
# Will be populated after function definitions
# Each value is called as fn(pipeline, cache_config) and returns a refresh callable.
CUSTOM_DIT_ENABLERS: dict[str, Callable] = {}
def _build_db_cache_config(cache_config: Any) -> DBCacheConfig:
    """Build a DBCacheConfig from the vllm-omni DiffusionCacheConfig.

    Note: SCM (Step Computation Masking) is NOT configured here — the SCM mask
    depends on num_inference_steps and is applied later by the per-model
    refresh functions via ``DBCacheConfig().reset(steps_computation_mask=...)``.

    Args:
        cache_config: DiffusionCacheConfig instance.

    Returns:
        DBCacheConfig mirroring the DBCache fields of ``cache_config``, with
        ``num_inference_steps`` left as None.
    """
    return DBCacheConfig(
        # we will refresh the context when gets num_inference_steps in the first inference request
        num_inference_steps=None,
        Fn_compute_blocks=cache_config.Fn_compute_blocks,
        Bn_compute_blocks=cache_config.Bn_compute_blocks,
        max_warmup_steps=cache_config.max_warmup_steps,
        max_cached_steps=cache_config.max_cached_steps,
        max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
        residual_diff_threshold=cache_config.residual_diff_threshold,
    )
def enable_cache_for_wan22(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for Wan2.2 dual-transformer architecture.

    Wan2.2 uses two transformers (transformer and transformer_2) that need
    to be enabled together using BlockAdapter. Per-transformer ParamsModifiers
    override warmup/cached-step limits (the high-noise transformer only runs a
    fraction of the schedule).

    Args:
        pipeline: The Wan2.2 pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function that can be called to update cache context with new
        num_inference_steps.
    """
    cache_dit.enable_cache(
        BlockAdapter(
            transformer=[
                pipeline.transformer,
                pipeline.transformer_2,
            ],
            blocks=[
                pipeline.transformer.blocks,
                pipeline.transformer_2.blocks,
            ],
            forward_pattern=[
                ForwardPattern.Pattern_2,
                ForwardPattern.Pattern_2,
            ],
            params_modifiers=[
                # high-noise transformer only have 30% steps
                ParamsModifier(
                    cache_config=DBCacheConfig().reset(
                        max_warmup_steps=cache_config.max_warmup_steps,
                        max_cached_steps=cache_config.max_cached_steps,
                    ),
                ),
                # NOTE(review): low-noise transformer limits are hard-coded
                # (warmup=2, cached=20) rather than taken from cache_config —
                # confirm this is intentional tuning for Wan2.2.
                ParamsModifier(
                    cache_config=DBCacheConfig().reset(
                        max_warmup_steps=2,
                        max_cached_steps=20,
                    ),
                ),
            ],
            has_separate_cfg=True,
        ),
        cache_config=DBCacheConfig(
            Fn_compute_blocks=cache_config.Fn_compute_blocks,
            Bn_compute_blocks=cache_config.Bn_compute_blocks,
            max_warmup_steps=cache_config.max_warmup_steps,
            max_cached_steps=cache_config.max_cached_steps,
            max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
            residual_diff_threshold=cache_config.residual_diff_threshold,
            # steps are only known at the first inference request; see refresh below
            num_inference_steps=None,
        ),
    )

    # from https://github.com/vipshop/cache-dit/pull/542
    def _split_inference_steps(num_inference_steps: int) -> tuple[int, int]:
        """Split inference steps into high-noise and low-noise steps for Wan2.2.

        Internal helper specific to Wan2.2's dual-transformer architecture: the
        pipeline's boundary_ratio determines which timesteps go to the
        high-noise transformer vs transformer_2. Closes over ``pipeline``.

        Args:
            num_inference_steps: Total number of inference steps.

        Returns:
            A tuple of (num_high_noise_steps, num_low_noise_steps).
        """
        if pipeline.boundary_ratio is not None:
            boundary_timestep = pipeline.boundary_ratio * pipeline.scheduler.config.num_train_timesteps
        else:
            boundary_timestep = None
        # Set timesteps to calculate the split
        # NOTE: this mutates scheduler state; presumably harmless because the
        # pipeline re-sets timesteps at generation start — confirm.
        device = next(pipeline.transformer.parameters()).device
        pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = pipeline.scheduler.timesteps
        num_high_noise_steps = 0
        # high-noise steps for transformer
        for t in timesteps:
            if boundary_timestep is None or t >= boundary_timestep:
                num_high_noise_steps += 1
        # low-noise steps for transformer_2
        num_low_noise_steps = num_inference_steps - num_high_noise_steps
        return num_high_noise_steps, num_low_noise_steps

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for both transformers with new num_inference_steps.

        Closes over the enabling call's ``cache_config`` for SCM settings.

        Args:
            pipeline: The Wan2.2 pipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        num_high_noise_steps, num_low_noise_steps = _split_inference_steps(num_inference_steps)
        # Refresh context for high-noise transformer
        if cache_config.scm_steps_mask_policy is None:
            # cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_high_noise_steps, verbose=verbose)
            cache_dit.refresh_context(
                pipeline.transformer,
                num_inference_steps=num_high_noise_steps,
                verbose=verbose,
            )
            cache_dit.refresh_context(
                pipeline.transformer_2,
                num_inference_steps=num_low_noise_steps,
                verbose=verbose,
            )
        else:
            # SCM path: rebuild the full DBCacheConfig so the step mask matches
            # each transformer's own step count.
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_high_noise_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_high_noise_steps
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )
            cache_dit.refresh_context(
                pipeline.transformer_2,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_low_noise_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_low_noise_steps
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context
def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for LongCatImage pipeline.

    LongCatImage has two block lists on a single transformer
    (transformer_blocks and single_transformer_blocks); both are adapted with
    ForwardPattern.Pattern_1.

    Args:
        pipeline: The LongCatImage pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function to update the cache context with a new
        num_inference_steps.
    """
    # Build DBCacheConfig for transformer
    db_cache_config = _build_db_cache_config(cache_config)
    calibrator = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
    # Build ParamsModifier for transformer
    modifier = ParamsModifier(
        cache_config=db_cache_config,
        calibrator_config=calibrator,
    )
    logger.info(
        f"Enabling cache-dit on LongCatImage transformer with BlockAdapter: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )
    # Enable cache-dit using BlockAdapter for transformer
    cache_dit.enable_cache(
        (
            BlockAdapter(
                transformer=pipeline.transformer,
                blocks=[
                    pipeline.transformer.transformer_blocks,
                    pipeline.transformer.single_transformer_blocks,
                ],
                forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
                params_modifiers=[modifier],
            )
        ),
        cache_config=db_cache_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The LongCatImage pipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            # SCM path: rebuild config so the step mask matches the step count.
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context
def enable_cache_for_flux(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Placeholder cache-dit enabler for the Flux pipeline.

    Registered in CUSTOM_DIT_ENABLERS so Flux is routed here instead of the
    generic single-transformer path, but the actual integration has not been
    written yet.

    Args:
        pipeline: The Flux pipeline instance (unused).
        cache_config: DiffusionCacheConfig instance (unused).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("cache-dit is not implemented for Flux pipeline.")
def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for StableDiffusion3Pipeline.

    Args:
        pipeline: The StableDiffusion3 pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function to update the cache context with a new
        num_inference_steps.
    """
    # Build DBCacheConfig for transformer
    db_cache_config = _build_db_cache_config(cache_config)
    calibrator = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
    # Build ParamsModifier for transformer
    modifier = ParamsModifier(
        cache_config=db_cache_config,
        calibrator_config=calibrator,
    )
    logger.info(
        f"Enabling cache-dit on StableDiffusion3 transformer with BlockAdapter: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )
    # Enable cache-dit using BlockAdapter for transformer
    cache_dit.enable_cache(
        (
            BlockAdapter(
                transformer=pipeline.transformer,
                blocks=pipeline.transformer.transformer_blocks,
                forward_pattern=ForwardPattern.Pattern_1,
                params_modifiers=[modifier],
            )
        ),
        cache_config=db_cache_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The StableDiffusion3 pipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            # SCM path: rebuild config so the step mask matches the step count.
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context
def enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for regular single-transformer DiT models.

    Fallback used for any pipeline not listed in CUSTOM_DIT_ENABLERS; lets
    cache-dit auto-detect the block structure of ``pipeline.transformer``.

    Args:
        pipeline: The diffusion pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function that can be called to update cache context with new
        num_inference_steps.
    """
    # Build DBCacheConfig (SCM masks are applied at refresh time)
    db_cache_config = _build_db_cache_config(cache_config)
    # Build calibrator config if TaylorSeer is enabled
    calibrator_config = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
    logger.info(
        f"Enabling cache-dit on transformer: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )
    # Enable cache-dit on the transformer
    cache_dit.enable_cache(
        pipeline.transformer,
        cache_config=db_cache_config,
        calibrator_config=calibrator_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context for the transformer with new num_inference_steps.

        Args:
            pipeline: The diffusion pipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            # SCM path: rebuild config so the step mask matches the step count.
            cache_dit.refresh_context(
                pipeline.transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context
class BagelCachedContextManager(CachedContextManager):
    """
    Custom CachedContextManager for Bagel that safely handles NaiveCache objects
    (mapped to encoder_hidden_states) by skipping tensor operations on them.
    """

    @torch.compiler.disable
    def apply_cache(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        prefix: str = "Bn",
        encoder_prefix: str = "Bn_encoder",
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Replay cached hidden states (residual or direct) for a skipped step.

        Unlike the upstream implementation, encoder_hidden_states may be a
        non-tensor (Bagel's NaiveCache); isinstance guards skip the add/
        contiguous calls in that case and pass the object through unchanged.
        """
        # Allow Bn and Fn prefix to be used for residual cache.
        if "Bn" in prefix:
            hidden_states_prev = self.get_Bn_buffer(prefix)
        else:
            hidden_states_prev = self.get_Fn_buffer(prefix)
        assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"
        if self.is_cache_residual():
            # Residual mode: cached value is a delta on top of the fresh input.
            hidden_states = hidden_states_prev + hidden_states
        else:
            # If cache is not residual, we use the hidden states directly
            hidden_states = hidden_states_prev
        hidden_states = hidden_states.contiguous()
        if encoder_hidden_states is not None:
            if "Bn" in encoder_prefix:
                encoder_hidden_states_prev = self.get_Bn_encoder_buffer(encoder_prefix)
            else:
                encoder_hidden_states_prev = self.get_Fn_encoder_buffer(encoder_prefix)
            if encoder_hidden_states_prev is not None:
                if self.is_encoder_cache_residual():
                    # FIX: Check if encoder_hidden_states is a tensor before adding
                    # (Bagel passes a NaiveCache here, which does not support +).
                    if isinstance(encoder_hidden_states, torch.Tensor) and isinstance(
                        encoder_hidden_states_prev, torch.Tensor
                    ):
                        encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states
                else:
                    # If encoder cache is not residual, we use the encoder hidden states directly
                    encoder_hidden_states = encoder_hidden_states_prev
            # FIX: Check if encoder_hidden_states is a tensor before calling contiguous
            if isinstance(encoder_hidden_states, torch.Tensor):
                encoder_hidden_states = encoder_hidden_states.contiguous()
        return hidden_states, encoder_hidden_states
class BagelCachedBlocks(CachedBlocks_Pattern_0_1_2):
    """
    Custom CachedBlocks for Bagel that safely handles NaiveCache objects
    by adding isinstance checks in call_Mn_blocks and compute_or_prune.
    """

    def call_Mn_blocks(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Run the middle (Mn) blocks and return outputs plus residuals.

        Returns a 4-tuple of (hidden_states, encoder_hidden_states,
        hidden_states_residual, encoder_hidden_states_residual); the encoder
        residual is None when encoder_hidden_states is not a plain tensor
        (Bagel's NaiveCache case).
        """
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        for block in self._Mn_blocks():
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            # Normalize the block's return shape into the two streams.
            hidden_states, encoder_hidden_states = self._process_block_outputs(
                hidden_states, encoder_hidden_states
            )
        # compute hidden_states residual
        hidden_states = hidden_states.contiguous()
        hidden_states_residual = hidden_states - original_hidden_states
        if (
            encoder_hidden_states is not None
            and original_encoder_hidden_states is not None
            and isinstance(encoder_hidden_states, torch.Tensor)  # FIX: Added Check
        ):
            encoder_hidden_states = encoder_hidden_states.contiguous()
            encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
        else:
            encoder_hidden_states_residual = None
        return (
            hidden_states,
            encoder_hidden_states,
            hidden_states_residual,
            encoder_hidden_states_residual,
        )

    def compute_or_prune(
        self,
        block_id: int,  # Block index in the transformer blocks
        # Below are the inputs to the block
        block,  # The transformer block to be executed
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Either replay a pruned block from cache or compute it and cache residuals.

        Tensor-safety variant of the upstream method: encoder residuals are only
        computed/cached when encoder_hidden_states is a real tensor.
        """
        # NOTE: Although Bagel likely won't use pruning, implementing safe version just in case.
        # Copy-pasted from original but adding checks.
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        can_use_prune = self._maybe_prune(
            block_id,
            hidden_states,
            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
        )
        torch._dynamo.graph_break()
        if can_use_prune:
            self.context_manager.add_pruned_step()
            hidden_states, encoder_hidden_states = self.context_manager.apply_prune(
                hidden_states,
                encoder_hidden_states,
                prefix=(
                    f"{self.cache_prefix}_{block_id}_Bn_residual"
                    if self.context_manager.is_cache_residual()
                    # NOTE(review): this non-residual read key omits _{block_id}_,
                    # while the matching set_Bn_buffer below includes it — looks
                    # like a key mismatch; confirm against upstream cache-dit.
                    else f"{self.cache_prefix}_Bn_hidden_states"
                ),
                encoder_prefix=(
                    f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
                    if self.context_manager.is_encoder_cache_residual()
                    else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
                ),
            )
            torch._dynamo.graph_break()
        else:
            # Normal steps: Compute the block and cache the residuals.
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            hidden_states, encoder_hidden_states = self._process_block_outputs(
                hidden_states, encoder_hidden_states
            )
            if not self._skip_prune(block_id):
                hidden_states = hidden_states.contiguous()
                hidden_states_residual = hidden_states - original_hidden_states
                if (
                    encoder_hidden_states is not None
                    and original_encoder_hidden_states is not None
                    and isinstance(encoder_hidden_states, torch.Tensor)  # FIX: Added Check
                ):
                    encoder_hidden_states = encoder_hidden_states.contiguous()
                    encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
                else:
                    encoder_hidden_states_residual = None
                self.context_manager.set_Fn_buffer(
                    original_hidden_states,
                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
                )
                if self.context_manager.is_cache_residual():
                    self.context_manager.set_Bn_buffer(
                        hidden_states_residual,
                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
                    )
                else:
                    self.context_manager.set_Bn_buffer(
                        hidden_states,
                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
                    )
                if encoder_hidden_states_residual is not None:
                    if self.context_manager.is_encoder_cache_residual():
                        self.context_manager.set_Bn_encoder_buffer(
                            encoder_hidden_states_residual,
                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
                        )
                    else:
                        # NOTE(review): this branch stores the *residual* under the
                        # non-residual "...hidden_states" key; upstream stores the
                        # raw encoder_hidden_states here — confirm intended.
                        self.context_manager.set_Bn_encoder_buffer(
                            encoder_hidden_states_residual,
                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
                        )
            torch._dynamo.graph_break()
        return hidden_states, encoder_hidden_states
class BagelCachedAdapter(CachedAdapter):
    """
    Custom CachedAdapter for Bagel that uses BagelCachedContextManager and BagelCachedBlocks.
    """

    @classmethod
    def create_context(
        cls,
        block_adapter: BlockAdapter,
        **context_kwargs,
    ) -> tuple[list[str], list[dict[str, Any]]]:
        """Create cache contexts for the pipeline, using BagelCachedContextManager.

        Mirrors the upstream CachedAdapter.create_context, substituting the
        Bagel-safe context manager. For non-persistent contexts, wraps the
        pipeline class's __call__ so each inference resets its contexts.
        """
        # Override to use BagelCachedContextManager
        BlockAdapter.assert_normalized(block_adapter)
        if BlockAdapter.is_cached(block_adapter.pipe):
            # NOTE(review): returns the pipe object, not the declared
            # (flatten_contexts, contexts_kwargs) tuple — matches upstream's
            # early-exit shape, but callers should not unpack this case; confirm.
            return block_adapter.pipe
        # Check context_kwargs
        context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs)
        # Each Pipeline should have it's own context manager instance.
        cache_config: BasicCacheConfig = context_kwargs.get("cache_config", None)
        assert cache_config is not None, "cache_config can not be None."
        # Apply cache on pipeline: wrap cache context
        pipe_cls_name = block_adapter.pipe.__class__.__name__
        # USE CUSTOM CONTEXT MANAGER
        context_manager = BagelCachedContextManager(
            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
            persistent_context=isinstance(block_adapter.pipe, FakeDiffusionPipeline),
        )
        flatten_contexts, contexts_kwargs = cls.modify_context_params(block_adapter, **context_kwargs)
        block_adapter.pipe._context_manager = context_manager  # instance level
        if not context_manager.persistent_context:
            original_call = block_adapter.pipe.__class__.__call__

            @functools.wraps(original_call)
            def new_call(self, *args, **kwargs):
                with ExitStack() as stack:
                    # cache context will be reset for each pipe inference
                    # (loop variable shadows the enclosing context_kwargs dict,
                    # intentionally or not — it is not used after this point)
                    for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
                        stack.enter_context(
                            context_manager.enter_context(
                                context_manager.reset_context(
                                    context_name,
                                    **context_kwargs,
                                ),
                            )
                        )
                    outputs = original_call(self, *args, **kwargs)
                    cls.apply_stats_hooks(block_adapter)
                    return outputs

            # NOTE: patches the *class*, so all instances of this pipeline class
            # share the wrapped __call__.
            block_adapter.pipe.__class__.__call__ = new_call
            block_adapter.pipe.__class__._original_call = original_call
        else:
            # Init persistent cache context for transformer
            for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
                context_manager.reset_context(
                    context_name,
                    **context_kwargs,
                )
        block_adapter.pipe.__class__._is_cached = True
        cls.apply_params_hooks(block_adapter, contexts_kwargs)
        return flatten_contexts, contexts_kwargs

    @classmethod
    def collect_unified_blocks(
        cls,
        block_adapter: BlockAdapter,
        contexts_kwargs: list[dict],
    ) -> list[dict[str, torch.nn.ModuleList]]:
        """Build BagelCachedBlocks wrappers for every (transformer, block list) pair.

        Returns one dict per transformer, mapping each unique blocks name to a
        ModuleList holding a single BagelCachedBlocks instance bound to the
        pipe's context manager.
        """
        # Override to use BagelCachedBlocks
        BlockAdapter.assert_normalized(block_adapter)
        total_cached_blocks: list[dict[str, torch.nn.ModuleList]] = []
        assert hasattr(block_adapter.pipe, "_context_manager")
        # Skipping isinstance check for ContextManager._supported_managers to avoid import issues
        for i in range(len(block_adapter.transformer)):
            unified_blocks_bind_context = {}
            for j in range(len(block_adapter.blocks[i])):
                # contexts_kwargs is flattened row-major over (transformer, blocks)
                cache_config: BasicCacheConfig = contexts_kwargs[
                    i * len(block_adapter.blocks[i]) + j
                ]["cache_config"]
                # Directly instantiate BagelCachedBlocks
                unified_blocks_bind_context[block_adapter.unique_blocks_name[i][j]] = torch.nn.ModuleList(
                    [
                        BagelCachedBlocks(
                            # 0. Transformer blocks configuration
                            block_adapter.blocks[i][j],
                            transformer=block_adapter.transformer[i],
                            forward_pattern=block_adapter.forward_pattern[i][j],
                            check_forward_pattern=block_adapter.check_forward_pattern,
                            check_num_outputs=block_adapter.check_num_outputs,
                            # 1. Cache/Prune context configuration
                            cache_prefix=block_adapter.blocks_name[i][j],
                            cache_context=block_adapter.unique_blocks_name[i][j],
                            context_manager=block_adapter.pipe._context_manager,
                            cache_type=cache_config.cache_type,
                        )
                    ]
                )
            total_cached_blocks.append(unified_blocks_bind_context)
        return total_cached_blocks
def enable_cache_for_bagel(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
    """Enable cache-dit for Bagel model (via OmniDiffusion pipeline).

    Uses BagelCachedAdapter (NaiveCache-safe) instead of the stock cache-dit
    adapter.

    Args:
        pipeline: The OmniDiffusion pipeline instance.
        cache_config: DiffusionCacheConfig instance with cache configuration.

    Returns:
        A refresh function that can be called to update cache context with new
        num_inference_steps.
    """
    # Build DBCacheConfig
    db_cache_config = _build_db_cache_config(cache_config)
    # Build calibrator config if TaylorSeer is enabled
    calibrator_config = None
    if cache_config.enable_taylorseer:
        taylorseer_order = cache_config.taylorseer_order
        calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
        logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
    # Access the transformer: BagelPipeline -> Qwen2MoTForCausalLM -> Qwen2MoTModel
    # BagelPipeline has self.language_model which is Qwen2MoTForCausalLM
    # Qwen2MoTForCausalLM has self.model which is Qwen2MoTModel
    transformer = pipeline.language_model.model
    logger.info(
        f"Enabling cache-dit on Bagel transformer: "
        f"Fn={db_cache_config.Fn_compute_blocks}, "
        f"Bn={db_cache_config.Bn_compute_blocks}, "
        f"W={db_cache_config.max_warmup_steps}, "
    )
    # Enable cache-dit on the transformer
    # Pattern_0 corresponds to (hidden_states, encoder_hidden_states) input, output
    # Custom adapter for Bagel to handle NaiveCache correctly
    # from vllm_omni.diffusion.cache.bagel_cache_adapter import BagelCachedAdapter # No longer needed
    BagelCachedAdapter.apply(
        BlockAdapter(
            transformer=transformer,
            blocks=transformer.layers,
            forward_pattern=ForwardPattern.Pattern_0,
        ),
        cache_config=db_cache_config,
        calibrator_config=calibrator_config,
    )

    def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh the Bagel transformer's cache context for a new step count."""
        transformer = pipeline.language_model.model
        if cache_config.scm_steps_mask_policy is None:
            cache_dit.refresh_context(transformer, num_inference_steps=num_inference_steps, verbose=verbose)
        else:
            # SCM path: rebuild config so the step mask matches the step count.
            cache_dit.refresh_context(
                transformer,
                cache_config=DBCacheConfig().reset(
                    num_inference_steps=num_inference_steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy=cache_config.scm_steps_mask_policy,
                        total_steps=num_inference_steps,
                    ),
                    steps_computation_policy=cache_config.scm_steps_policy,
                ),
                verbose=verbose,
            )

    return refresh_cache_context
# Register custom cache-dit enablers after function definitions
# (keys are pipeline class names, matched in CacheDiTBackend.enable)
CUSTOM_DIT_ENABLERS.update(
    {
        "Wan22Pipeline": enable_cache_for_wan22,
        "Wan22I2VPipeline": enable_cache_for_wan22,
        "Wan22TI2VPipeline": enable_cache_for_wan22,
        "FluxPipeline": enable_cache_for_flux,
        "LongCatImagePipeline": enable_cache_for_longcat_image,
        "LongCatImageEditPipeline": enable_cache_for_longcat_image,
        "StableDiffusion3Pipeline": enable_cache_for_sd3,
        "BagelPipeline": enable_cache_for_bagel,
    }
)
class CacheDiTBackend(CacheBackend):
    """Backend class for cache-dit acceleration on diffusion pipelines.

    This class implements cache-dit acceleration (DBCache, SCM, TaylorSeer) using
    the cache-dit library. It inherits from CacheBackend and provides a unified
    interface for managing cache-dit acceleration on diffusion models.

    Attributes:
        config: Cache configuration (DiffusionCacheConfig instance), inherited from CacheBackend.
        enabled: Whether cache-dit is enabled on this pipeline, inherited from CacheBackend.
        _refresh_func: Internal refresh function for updating cache context.
        _last_num_inference_steps: Last num_inference_steps used for refresh optimization.
    """

    def __init__(self, cache_config: Any = None):
        """Initialize the cache-dit backend.

        Args:
            cache_config: Cache configuration (DiffusionCacheConfig instance, dict, or None).
                If None or empty, uses default DiffusionCacheConfig().
        """
        # Use default config if cache_config is not provided or is empty
        if cache_config is None:
            config = DiffusionCacheConfig()
        elif isinstance(cache_config, dict):
            # Convert dict to DiffusionCacheConfig, using defaults for missing keys
            config = DiffusionCacheConfig.from_dict(cache_config)
        else:
            config = cache_config
        # Initialize base class with normalized config
        super().__init__(config)
        # Cache-dit specific attributes
        self._refresh_func: Callable[[Any, int, bool], None] | None = None
        self._last_num_inference_steps: int | None = None

    def enable(self, pipeline: Any) -> None:
        """Enable cache-dit on the pipeline if configured.

        Applies cache-dit to the appropriate transformer(s): pipelines listed in
        CUSTOM_DIT_ENABLERS (e.g. Wan2.2 dual-transformer, Bagel) get their
        dedicated enabler; everything else goes through enable_cache_for_dit.

        Args:
            pipeline: The diffusion pipeline instance.
        """
        # Extract pipeline name from pipeline
        pipeline_name = pipeline.__class__.__name__
        # Check if this model has a custom cache-dit enabler
        if pipeline_name in CUSTOM_DIT_ENABLERS:
            logger.info(f"Using custom cache-dit enabler for model: {pipeline_name}")
            self._refresh_func = CUSTOM_DIT_ENABLERS[pipeline_name](pipeline, self.config)
        else:
            # For regular single-transformer models
            self._refresh_func = enable_cache_for_dit(pipeline, self.config)
        self.enabled = True
        logger.info(f"Cache-dit enabled successfully on {pipeline_name}")

    def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """Refresh cache context with new num_inference_steps.

        No-op unless num_inference_steps differs from the last refreshed value
        (cheap dedup so per-request refresh calls are free when steps repeat).
        For dual-transformer models the enabler's refresh splits the steps.

        Args:
            pipeline: The diffusion pipeline instance.
            num_inference_steps: New number of inference steps.
            verbose: Whether to log refresh operations.
        """
        if not self.enabled or self._refresh_func is None:
            logger.warning("Cache-dit is not enabled. Cannot refresh cache context.")
            return
        # Only refresh if num_inference_steps has changed
        if self._last_num_inference_steps is None or num_inference_steps != self._last_num_inference_steps:
            if verbose:
                logger.info(
                    f"Refreshing cache context for transformer with num_inference_steps: {num_inference_steps}"
                )
            self._refresh_func(pipeline, num_inference_steps, verbose)
            self._last_num_inference_steps = num_inference_steps

    def is_enabled(self) -> bool:
        """Check if cache-dit is enabled on this pipeline.

        Returns:
            True if cache-dit is enabled, False otherwise.
        """
        return self.enabled
def may_enable_cache_dit(pipeline: Any, od_config: OmniDiffusionConfig) -> Optional["CacheDiTBackend"]:
    """Enable cache-dit on the pipeline if configured (convenience function).

    This is a convenience function that creates and enables a CacheDiTBackend.
    For new code, consider using CacheDiTBackend directly.

    Args:
        pipeline: The diffusion pipeline instance.
        od_config: OmniDiffusionConfig with cache configuration.

    Returns:
        A CacheDiTBackend instance if cache-dit is enabled, None otherwise.
    """
    # NOTE(review): this checks for the spelling "cache-dit", while
    # selector.get_cache_backend dispatches on "cache_dit" — confirm which
    # spelling od_config.cache_backend actually carries.
    if od_config.cache_backend != "cache-dit" or not od_config.cache_config:
        return None
    backend = CacheDiTBackend(od_config.cache_config)
    backend.enable(pipeline)
    return backend if backend.is_enabled() else None
vllm_omni/diffusion/cache/selector.py
0 → 100644
View file @
c1cacde6
from
typing
import
Any
from
vllm_omni.diffusion.cache.base
import
CacheBackend
from
vllm_omni.diffusion.cache.cache_dit_backend
import
CacheDiTBackend
from
vllm_omni.diffusion.cache.teacache.backend
import
TeaCacheBackend
from
vllm_omni.diffusion.data
import
DiffusionCacheConfig
def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBackend | None:
    """Instantiate the cache backend named by ``cache_backend``.

    Routes to the matching implementation:
    - "cache_dit" -> CacheDiTBackend (enable()/refresh() interface)
    - "tea_cache" -> TeaCacheBackend (enable()/refresh() interface)

    Args:
        cache_backend: Backend name ("cache_dit", "tea_cache", "none", or None).
        cache_config: Cache configuration (dict or DiffusionCacheConfig);
            a dict is normalized via DiffusionCacheConfig.from_dict.

    Returns:
        The constructed backend, or None when no backend is requested.

    Raises:
        ValueError: If cache_backend names an unsupported backend.
    """
    if cache_backend is None or cache_backend == "none":
        return None
    if isinstance(cache_config, dict):
        cache_config = DiffusionCacheConfig.from_dict(cache_config)
    factories = {
        "cache_dit": CacheDiTBackend,
        "tea_cache": TeaCacheBackend,
    }
    factory = factories.get(cache_backend)
    if factory is None:
        raise ValueError(f"Unsupported cache backend: {cache_backend}. Supported: 'cache_dit', 'tea_cache'")
    return factory(cache_config)
vllm_omni/diffusion/cache/teacache/__init__.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
TeaCache: Timestep Embedding Aware Cache for diffusion model acceleration.
TeaCache speeds up diffusion inference by reusing transformer block computations
when consecutive timestep embeddings are similar.
This implementation uses a hooks-based approach that requires zero changes to
model code. Model developers only need to add an extractor function to support
new models.
Usage:
from vllm_omni import Omni
omni = Omni(
model="Qwen/Qwen-Image",
cache_backend="tea_cache",
cache_config={"rel_l1_thresh": 0.2}
)
images = omni.generate("a cat")
# Alternative: Using environment variable
# export DIFFUSION_CACHE_BACKEND=tea_cache
"""
from
vllm_omni.diffusion.cache.teacache.backend
import
TeaCacheBackend
from
vllm_omni.diffusion.cache.teacache.config
import
TeaCacheConfig
from
vllm_omni.diffusion.cache.teacache.extractors
import
(
CacheContext
,
register_extractor
,
)
from
vllm_omni.diffusion.cache.teacache.hook
import
TeaCacheHook
,
apply_teacache_hook
from
vllm_omni.diffusion.cache.teacache.state
import
TeaCacheState
# Public API of the teacache package; controls what `from ... import *` exposes.
__all__ = [
    "TeaCacheBackend",
    "TeaCacheConfig",
    "TeaCacheState",
    "TeaCacheHook",
    "apply_teacache_hook",
    "register_extractor",
    "CacheContext",
]
vllm_omni/diffusion/cache/teacache/backend.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
TeaCache backend implementation.
This module provides the TeaCache backend that implements the CacheBackend
interface using the hooks-based TeaCache system.
"""
from
typing
import
Any
from
vllm.logger
import
init_logger
from
vllm_omni.diffusion.cache.base
import
CacheBackend
from
vllm_omni.diffusion.cache.teacache.config
import
TeaCacheConfig
from
vllm_omni.diffusion.cache.teacache.hook
import
TeaCacheHook
,
apply_teacache_hook
from
vllm_omni.diffusion.data
import
DiffusionCacheConfig
logger
=
init_logger
(
__name__
)
def enable_bagel_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None:
    """Enable TeaCache for the Bagel model.

    Bagel exposes its transformer as ``pipeline.bagel`` and denoises through
    ``_forward_flow`` rather than ``forward``. We therefore alias ``forward``
    to the original ``_forward_flow``, let ``apply_teacache_hook`` wrap
    ``forward``, then point ``_forward_flow`` back at the hooked ``forward``
    so the denoising path runs through the cache.

    Args:
        pipeline: Bagel pipeline instance (must expose ``.bagel``).
        config: Cache configuration providing ``rel_l1_thresh`` and
            ``coefficients``.
    """
    # Keep the import local (as the original did) but hoisted to the top of
    # the function rather than buried mid-body.
    import types

    teacache_config = TeaCacheConfig(
        transformer_type="Bagel",
        rel_l1_thresh=config.rel_l1_thresh,
        coefficients=config.coefficients,
    )
    transformer = pipeline.bagel
    original_forward_flow = transformer._forward_flow

    def forward_alias(self, *args, **kwargs):
        # `self` is supplied by MethodType binding and intentionally unused;
        # delegate to the captured, un-hooked flow.
        return original_forward_flow(*args, **kwargs)

    transformer.forward = types.MethodType(forward_alias, transformer)
    apply_teacache_hook(transformer, teacache_config)
    # Re-route the pipeline's denoising entry point through the hooked forward.
    transformer._forward_flow = transformer.forward
    pipeline.transformer = transformer
    # Lazy %-style args so formatting is skipped when the level is disabled.
    logger.info(
        "TeaCache applied with rel_l1_thresh=%s, transformer_class=%s",
        teacache_config.rel_l1_thresh,
        teacache_config.transformer_type,
    )
# Pipelines that need bespoke TeaCache wiring, keyed by pipeline class name.
# Each value is a callable(pipeline, DiffusionCacheConfig) -> None.
CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache}
class TeaCacheBackend(CacheBackend):
    """
    TeaCache implementation using hooks.

    TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique
    that speeds up diffusion inference by reusing transformer block computations
    when consecutive timestep embeddings are similar.

    The backend applies TeaCache hooks to the transformer which intercept the
    forward pass and implement the caching logic transparently.

    Example:
        >>> from vllm_omni.diffusion.data import DiffusionCacheConfig
        >>> backend = TeaCacheBackend(DiffusionCacheConfig(rel_l1_thresh=0.2))
        >>> backend.enable(pipeline)
        >>> # Generate with cache enabled
        >>> backend.refresh(pipeline, num_inference_steps=50)  # Refresh before each generation
        >>> # Access config attributes: backend.config.rel_l1_thresh
    """

    def enable(self, pipeline: Any) -> None:
        """
        Enable TeaCache on transformer using hooks.

        This creates a TeaCacheConfig from the backend's DiffusionCacheConfig
        and applies the TeaCache hook to the transformer.

        Args:
            pipeline: Diffusion pipeline instance. Extracts transformer and transformer_type:
                - transformer: pipeline.transformer
                - transformer_type: pipeline.transformer.__class__.__name__

        Raises:
            ValueError: If the TeaCacheConfig cannot be constructed from the
                backend's configuration.
        """
        pipeline_type = pipeline.__class__.__name__
        # Pipeline-level custom enablers (e.g. Bagel) take precedence over the
        # generic pipeline.transformer path.
        if pipeline_type in CUSTOM_TEACACHE_ENABLERS:
            logger.info("Using custom TeaCache enabler for model: %s", pipeline_type)
            CUSTOM_TEACACHE_ENABLERS[pipeline_type](pipeline, self.config)
        else:
            transformer = pipeline.transformer
            transformer_type = transformer.__class__.__name__
            # rel_l1_thresh already has a default value in DiffusionCacheConfig.
            try:
                teacache_config = TeaCacheConfig(
                    transformer_type=transformer_type,
                    rel_l1_thresh=self.config.rel_l1_thresh,
                    coefficients=self.config.coefficients,
                )
            except Exception as e:
                logger.error("Failed to create TeaCacheConfig: %s", e)
                # Chain the original exception so the root cause is preserved
                # in the traceback (`raise ... from e`).
                raise ValueError(
                    f"Invalid TeaCache configuration: {e}. "
                    "Expected keys: rel_l1_thresh, coefficients (optional). "
                    "transformer_type is automatically extracted from pipeline.transformer.__class__.__name__."
                ) from e
            # Apply hook to transformer.
            apply_teacache_hook(transformer, teacache_config)
            logger.info(
                "TeaCache applied with rel_l1_thresh=%s, transformer_class=%s",
                teacache_config.rel_l1_thresh,
                teacache_config.transformer_type,
            )
        # Mark as enabled.
        self.enabled = True

    def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
        """
        Refresh TeaCache state for new generation.

        Clears all cached residuals and resets counters/accumulators.
        Should be called before each generation to ensure clean state.

        Args:
            pipeline: Diffusion pipeline instance. Extracts transformer via pipeline.transformer.
            num_inference_steps: Number of inference steps for the current generation.
                Currently not used by TeaCache but accepted for interface consistency.
            verbose: Whether to log refresh operations (default: True)
        """
        transformer = pipeline.transformer
        if hasattr(transformer, "_hook_registry"):
            hook = transformer._hook_registry.get_hook(TeaCacheHook._HOOK_NAME)
            if hook is not None:
                transformer._hook_registry.reset_hook(TeaCacheHook._HOOK_NAME)
                if verbose:
                    logger.debug(
                        "TeaCache state refreshed (num_inference_steps=%s)",
                        num_inference_steps,
                    )
            else:
                if verbose:
                    logger.warning("TeaCache hook not found, nothing to refresh")
        else:
            if verbose:
                logger.warning("Transformer has no hook registry, TeaCache may not be applied")
vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
0 → 100644
View file @
c1cacde6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
types
from
typing
import
Any
import
numpy
as
np
import
torch
from
vllm.config
import
LoadConfig
from
vllm_omni.diffusion.cache.teacache.extractors
import
get_extractor
from
vllm_omni.diffusion.data
import
OmniDiffusionConfig
from
vllm_omni.diffusion.hooks
import
HookRegistry
,
ModelHook
from
vllm_omni.diffusion.model_loader.diffusers_loader
import
DiffusersPipelineLoader
from
vllm_omni.diffusion.models.bagel.pipeline_bagel
import
BagelPipeline
class DataCollectionHook(ModelHook):
    """Hook to collect modulated inputs and model outputs for TeaCache coefficient estimation."""

    # Name under which this hook registers itself in the hook registry.
    _HOOK_NAME = "teacache_collector"

    def __init__(self, transformer_type: str):
        super().__init__()
        # Key used to look up the model-specific extractor in initialize_hook().
        self.transformer_type = transformer_type
        # Resolved lazily in initialize_hook(); None until then.
        self.extractor_fn = None
        # One (modulated_input, model_output) CPU-array pair per forward call.
        self.current_trajectory: list[tuple[np.ndarray, np.ndarray]] = []

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        """Resolve the extractor for this transformer type; return module unchanged."""
        self.extractor_fn = get_extractor(self.transformer_type)
        return module

    def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped forward while recording (input, output) pairs.

        The extractor builds a context exposing the modulated input, a
        transformer-block runner, and a postprocess step; the modulated input
        and the block output are snapshotted to CPU numpy arrays before the
        normal result is returned.
        """
        ctx = self.extractor_fn(module, *args, **kwargs)
        # Snapshot before running the blocks, in case they mutate the context.
        modulated_input_cpu = ctx.modulated_input.detach().cpu().numpy()
        outputs = ctx.run_transformer_blocks()
        ctx.hidden_states = outputs[0]
        # Some models also return updated encoder hidden states as outputs[1].
        if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
            ctx.encoder_hidden_states = outputs[1]
        model_output_cpu = ctx.hidden_states.detach().cpu().numpy()
        self.current_trajectory.append((modulated_input_cpu, model_output_cpu))
        return ctx.postprocess(ctx.hidden_states)

    def start_collection(self):
        """Discard any previously recorded trajectory."""
        self.current_trajectory = []

    def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
        """Return a shallow copy of the recorded (input, output) pairs."""
        return list(self.current_trajectory)
class BagelAdapter:
    """Adapter for Bagel model."""

    @staticmethod
    def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
        """Build a BagelPipeline from model_path, load weights, move to device."""
        od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
        od_config.model_class_name = "BagelPipeline"
        pipeline = BagelPipeline(od_config=od_config)
        loader = DiffusersPipelineLoader(LoadConfig())
        loader.load_weights(pipeline)
        pipeline.to(device)
        return pipeline

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        """Return Bagel's transformer module and its extractor key."""
        return pipeline.bagel, "Bagel"

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        """Install the collection hook on Bagel's custom forward flow.

        Bagel denoises through ``_forward_flow`` instead of ``forward``, so
        ``forward`` is first aliased to the original flow, the hook registry
        wraps ``forward``, and ``_forward_flow`` is then pointed at the hooked
        ``forward`` — mirroring enable_bagel_teacache in backend.py.
        """
        original_forward_flow = transformer._forward_flow

        def forward_alias(self, *args, **kwargs):
            # `self` comes from MethodType binding and is intentionally unused.
            return original_forward_flow(*args, **kwargs)

        transformer.forward = types.MethodType(forward_alias, transformer)
        registry = HookRegistry.get_or_create(transformer)
        registry.register_hook(hook._HOOK_NAME, hook)
        transformer._forward_flow = transformer.forward
class DefaultAdapter:
    """Default adapter for standard diffusers pipelines."""

    @staticmethod
    def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
        """Loading is not supported by the default adapter."""
        raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")

    @staticmethod
    def get_transformer(pipeline: Any) -> tuple[Any, str]:
        """Return the pipeline's transformer and its class name."""
        transformer = pipeline.transformer
        return transformer, type(transformer).__name__

    @staticmethod
    def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
        """Register the collection hook directly on the transformer."""
        HookRegistry.get_or_create(transformer).register_hook(hook._HOOK_NAME, hook)
# Registry mapping model_type names to adapter classes that know how to load
# the pipeline, locate its transformer, and install the data-collection hook.
_MODEL_ADAPTERS: dict[str, type] = {
    "Bagel": BagelAdapter,
}
# Stabilizer added to denominators to avoid division by zero in relative-L1.
_EPSILON = 1e-6
def calculate_relative_l1(
    tensor_current: np.ndarray,
    tensor_next: np.ndarray,
    epsilon: float = 1e-6,
) -> float:
    """Calculate relative L1 distance (Eq. 4 from TeaCache paper).

    Computes ``sum(|current - next|) / (sum(|current|) + epsilon)``.

    Args:
        tensor_current: Feature array at the current timestep.
        tensor_next: Feature array at the next timestep.
        epsilon: Denominator stabilizer to avoid division by zero; the default
            equals the module-level ``_EPSILON`` (1e-6).

    Returns:
        The relative L1 distance as a plain Python float.
    """
    diff = np.abs(tensor_current - tensor_next).sum()
    norm = np.abs(tensor_current).sum() + epsilon
    # Cast from np.float64 so callers receive a plain float as annotated.
    return float(diff / norm)
def estimate_teacache_coefficients(
    collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4
) -> list[float]:
    """Estimate polynomial coefficients for TeaCache using np.polyfit.

    Builds (input-diff, output-diff) pairs from consecutive timesteps of each
    collected trajectory, then fits a polynomial mapping input differences to
    output differences.

    Args:
        collected_data: One entry per generation; each entry is the per-timestep
            list of (modulated_input, model_output) arrays.
        poly_order: Order of the fitted polynomial (default: 4).

    Returns:
        Polynomial coefficients, highest order first (np.polyfit convention).

    Raises:
        ValueError: If no consecutive-timestep pairs exist (empty data, or all
            trajectories shorter than two timesteps).
    """
    input_diffs: list[float] = []
    output_diffs: list[float] = []
    for sample in collected_data:
        # Pair each timestep with its successor within the same trajectory.
        for (feat_in_curr, feat_out_curr), (feat_in_next, feat_out_next) in zip(sample, sample[1:]):
            input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next))
            output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next))
    if not input_diffs:
        # Guard before the reductions below, which would otherwise raise an
        # opaque "zero-size array to reduction operation" numpy error.
        raise ValueError(
            "No consecutive-timestep pairs in collected_data; need at least "
            "one trajectory with two or more timesteps."
        )
    x = np.array(input_diffs, dtype=np.float64)
    y = np.array(output_diffs, dtype=np.float64)
    print("Data statistics:")
    print(f"  Count: {len(x)}")
    print(f"  Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}")
    print(f"  Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}")
    return np.polyfit(x, y, poly_order).tolist()
class TeaCacheCoefficientEstimator:
    """Model-agnostic helper class to collect data and estimate TeaCache coefficients."""

    def __init__(
        self,
        model_path: str,
        model_type: str = "Bagel",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Load the pipeline for ``model_type`` and install the collection hook.

        Args:
            model_path: Path or model id passed to the adapter's loader.
            model_type: Key into ``_MODEL_ADAPTERS`` selecting the adapter.
            device: Device the pipeline is moved to.
            dtype: Torch dtype used when loading the pipeline.

        Raises:
            ValueError: If ``model_type`` has no registered adapter.
        """
        if model_type not in _MODEL_ADAPTERS:
            available_types = list(_MODEL_ADAPTERS.keys())
            raise ValueError(
                f"Unsupported model_type: '{model_type}'. "
                f"Available types: {available_types}. "
                "To add support for a new model, add an entry to _MODEL_ADAPTERS."
            )
        # Membership was validated above, so a plain lookup cannot fail
        # (the original .get(..., DefaultAdapter) fallback was unreachable).
        adapter = _MODEL_ADAPTERS[model_type]
        self.pipeline = adapter.load_pipeline(model_path, device, dtype)
        self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
        self.hook = DataCollectionHook(self.transformer_type)
        # One trajectory (list of (input, output) pairs) per generation.
        self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = []
        adapter.install_hook(self.transformer, self.hook)

    def collect_from_prompt(self, prompt: str, **generate_kwargs):
        """Run one generation for ``prompt`` and record its trajectory.

        Recognized generate_kwargs: ``num_inference_steps`` (default 20) and
        ``seed`` (default 42); other keys are ignored.
        """
        self.hook.start_collection()
        # Local import — NOTE(review): presumably avoids a module-level import
        # cycle with the request machinery; confirm before hoisting.
        from vllm_omni.diffusion.request import OmniDiffusionRequest

        req = OmniDiffusionRequest(
            prompt=prompt,
            num_inference_steps=generate_kwargs.get("num_inference_steps", 20),
            seed=generate_kwargs.get("seed", 42),
        )
        self.pipeline.forward(req)
        trajectory = self.hook.stop_collection()
        # Skip generations where the hook recorded no forward passes.
        if trajectory:
            self.collected_data.append(trajectory)

    def estimate(self, poly_order: int = 4) -> list[float]:
        """Estimate polynomial coefficients from collected data.

        Args:
            poly_order: Order of polynomial fit (default: 4)

        Returns:
            List of polynomial coefficients [a_n, a_{n-1}, ..., a_1, a_0]

        Raises:
            RuntimeError: If no data has been collected
        """
        if not self.collected_data:
            raise RuntimeError(
                "No data collected for coefficient estimation. "
                "Call collect_from_prompt() at least once before calling estimate()."
            )
        return estimate_teacache_coefficients(self.collected_data, poly_order)
Prev
1
…
11
12
13
14
15
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment