Commit c1cacde6 authored by weishb

vllm-omni_0.15.0.rc1+fix1 first commit

parent 35607782
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
import torch
from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType, select_flash_attn_impl
from vllm_omni.diffusion.attention.backends.ring.ring_utils import update_out_and_lse
from vllm_omni.diffusion.distributed.comm import RingComm
def ring_flash_attn_forward(
process_group,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale,
dropout_p=0,
causal=True,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
attn_type: AttnType = AttnType.FA,
attn_processor=None,
joint_tensor_key=None,
joint_tensor_value=None,
joint_strategy="front",
):
# Validate causal + joint_strategy combination
# When causal=True and joint_strategy="rear", the causal mask would incorrectly
# prevent local query tokens from attending to joint key tokens (which are
# concatenated at the end). This breaks the semantics where joint tokens
# (e.g., text conditioning) should be visible to all local tokens.
if causal and joint_tensor_key is not None and joint_strategy == "rear":
raise ValueError(
"joint_strategy='rear' is not compatible with causal=True in Ring Attention. "
"When using causal attention with joint tokens, use joint_strategy='front' "
"to ensure joint tokens act as a visible prefix for all local tokens. "
"With 'rear' strategy, the causal mask would incorrectly block local tokens "
"from seeing the joint tokens."
)
comm = RingComm(process_group)
out = None
lse = None
next_k, next_v = None, None
# Check and adjust q, k, v to be contiguous
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
k = k.contiguous()
if not v.is_contiguous():
v = v.contiguous()
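# Ring schedule: the sequence dimension is sharded across ranks, so each rank holds one
# local Q/K/V block. At every step the current K/V block is forwarded to the next rank
# (send_recv + commit/wait overlaps communication with compute), so after world_size
# steps each rank has attended its local Q against every K/V block. Partial results are
# merged by update_out_and_lse using the running log-sum-exp, which keeps the result
# numerically equivalent to a single softmax over the full sequence. With causal=True,
# K/V blocks originating from higher ranks lie entirely in the future of this rank's
# queries and are skipped (step <= comm.rank); only step 0 (the local block) needs the
# causal mask inside the kernel.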
for step in range(comm.world_size):
if step + 1 != comm.world_size:
next_k: torch.Tensor
next_v: torch.Tensor
next_k = comm.send_recv(k)
next_v = comm.send_recv(v)
comm.commit()
if not causal or step <= comm.rank:
step_k = k
step_v = v
if step == 0 and joint_tensor_key is not None:
if joint_strategy == "front":
step_k = torch.cat([joint_tensor_key, step_k], dim=1)
step_v = torch.cat([joint_tensor_value, step_v], dim=1)
else:
step_k = torch.cat([step_k, joint_tensor_key], dim=1)
step_v = torch.cat([step_v, joint_tensor_value], dim=1)
fn = select_flash_attn_impl(attn_type, stage="fwd-only", attn_processor=attn_processor)
block_out, block_lse = fn(
q,
step_k,
step_v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal and step == 0,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=dropout_p > 0,
)
# Ensure block_out is contiguous if needed, though usually it is from FA
if attn_type == AttnType.SPARSE_SAGE:
out, lse = block_out, block_lse
else:
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != comm.world_size:
comm.wait()
k = next_k
v = next_v
out = out.to(q.dtype)
if attn_type != AttnType.SPARSE_SAGE:
lse = lse.squeeze(dim=-1).transpose(1, 2)
return out, lse
class RingFlashAttnFunc(torch.autograd.Function):
"""Ring Flash Attention autograd function (inference only, no backward)."""
@staticmethod
def forward(
ctx,
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
group,
attn_type,
attn_processor,
joint_tensor_key=None,
joint_tensor_value=None,
joint_strategy="front",
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
assert alibi_slopes is None
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out, softmax_lse = ring_flash_attn_forward(
group,
q,
k,
v,
softmax_scale=softmax_scale,
dropout_p=dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=False,
attn_type=attn_type,
attn_processor=attn_processor,
joint_tensor_key=joint_tensor_key,
joint_tensor_value=joint_tensor_value,
joint_strategy=joint_strategy,
)
return out if not return_softmax else (out, softmax_lse, None)
def ring_flash_attn_qkvpacked_func(
qkv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
group=None,
attn_type: AttnType = AttnType.FA,
):
return RingFlashAttnFunc.apply(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
dropout_p,
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
group,
attn_type,
None, # attn_processor
None, # joint_tensor_key
None, # joint_tensor_value
"front", # joint_strategy
)
def ring_flash_attn_kvpacked_func(
q,
kv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
group=None,
attn_type: AttnType = AttnType.FA,
):
return RingFlashAttnFunc.apply(
q,
kv[:, :, 0],
kv[:, :, 1],
dropout_p,
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
group,
attn_type,
None, # attn_processor
None, # joint_tensor_key
None, # joint_tensor_value
"front", # joint_strategy
)
def ring_flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
group=None,
attn_type: AttnType = AttnType.FA,
attn_processor=None,
joint_tensor_key=None,
joint_tensor_value=None,
joint_strategy="front",
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, None]:
"""Ring Attention forward pass using Flash Attention backend.
Implements Ring Attention with sequence parallelism using a ring-based P2P
communication pattern. The sequence dimension is sharded across devices, and
Key/Value blocks are circulated through the ring to accumulate attention results.
Args:
q (torch.Tensor): Query tensor of shape (batch, seq_len, num_heads, head_dim).
Sequence dimension is sharded across the ring group.
k (torch.Tensor): Key tensor of shape (batch, seq_len, num_heads, head_dim).
Sequence dimension is sharded across the ring group.
v (torch.Tensor): Value tensor of shape (batch, seq_len, num_heads, head_dim).
Sequence dimension is sharded across the ring group.
dropout_p (float): Dropout probability. Defaults to 0.0.
softmax_scale (float | None): Scaling factor for softmax.
If None, computed as head_dim^(-0.5).
causal (bool): Whether to apply causal masking. Defaults to False.
window_size (tuple[int, int]): Sliding window size for attention.
(-1, -1) means no windowing.
softcap (float): Soft capping value for attention logits. Defaults to 0.0.
alibi_slopes (torch.Tensor | None): ALiBi slopes for positional bias.
Not supported.
deterministic (bool): Whether to use deterministic algorithms.
Defaults to False.
return_attn_probs (bool): If True, returns (out, softmax_lse, None).
Defaults to False.
group (ProcessGroup | None): Process group for ring communication.
Defaults to None.
attn_type (AttnType): Flash Attention implementation type
(AttnType.FA, AttnType.FA3, etc.).
attn_processor (Callable | None): Custom attention processor for sparse
attention. Defaults to None.
joint_tensor_key (torch.Tensor | None): Additional key tensor for joint
attention (e.g., text + image). Concatenated only at step=0.
Defaults to None.
joint_tensor_value (torch.Tensor | None): Additional value tensor for
joint attention (e.g., text + image). Concatenated only at step=0.
Defaults to None.
joint_strategy (str): Concatenation strategy ("front" or "rear").
Defaults to "front".
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, None]]:
- If return_attn_probs is False: Output tensor (batch, seq_len, num_heads, head_dim).
- If return_attn_probs is True: A tuple (out, softmax_lse, None).
"""
return RingFlashAttnFunc.apply(
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
group,
attn_type,
attn_processor,
joint_tensor_key,
joint_tensor_value,
joint_strategy,
)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Jiarui Fang.
# Adapted from https://github.com/feifeibear/long-context-attention
# adapted from https://github.com/huggingface/picotron/blob/main/picotron/context_parallel/context_parallel.py
# Copyright 2024 The HuggingFace Inc. team and Jiarui Fang.
import torch
from vllm.logger import init_logger
from vllm_omni.diffusion.attention.backends.ring.ring_kernels import pytorch_attn_forward
from vllm_omni.diffusion.attention.backends.ring.ring_utils import update_out_and_lse
from vllm_omni.diffusion.distributed.comm import RingComm
logger = init_logger(__name__)
def ring_pytorch_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
group=None,
op_type="efficient",
joint_tensor_key=None,
joint_tensor_value=None,
joint_strategy="front",
):
return RingAttentionFunc.apply(
group,
q,
k,
v,
softmax_scale,
causal,
op_type,
joint_tensor_key,
joint_tensor_value,
joint_strategy,
)
class RingAttentionFunc(torch.autograd.Function):
"""Ring Attention autograd function using PyTorch SDPA (inference only, no backward)."""
@staticmethod
def forward(
ctx,
group,
q,
k,
v,
sm_scale,
is_causal,
op_type,
joint_tensor_key=None,
joint_tensor_value=None,
joint_strategy="front",
):
# Validate causal + joint_strategy combination
# When causal=True and joint_strategy="rear", the causal mask would incorrectly
# prevent local query tokens from attending to joint key tokens (which are
# concatenated at the end). This breaks the semantics where joint tokens
# (e.g., text conditioning) should be visible to all local tokens.
if is_causal and joint_tensor_key is not None and joint_strategy == "rear":
raise ValueError(
"joint_strategy='rear' is not compatible with causal=True in Ring Attention. "
"When using causal attention with joint tokens, use joint_strategy='front' "
"to ensure joint tokens act as a visible prefix for all local tokens. "
"With 'rear' strategy, the causal mask would incorrectly block local tokens "
"from seeing the joint tokens."
)
comm = RingComm(group)
# Ensure tensors are contiguous for P2P communication
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out, lse = None, None
next_k, next_v = None, None
if sm_scale is None:
sm_scale = q.shape[-1] ** -0.5
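# Same ring schedule as ring_flash_attn_forward: circulate K/V blocks around the ring
# and merge partial results via update_out_and_lse, but compute each block with the
# PyTorch SDPA kernel instead of Flash Attention.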
for step in range(comm.world_size):
if step + 1 != comm.world_size:
next_k = comm.send_recv(k)
next_v = comm.send_recv(v)
comm.commit()
if not is_causal or step <= comm.rank:
step_k = k
step_v = v
if step == 0 and joint_tensor_key is not None:
if joint_strategy == "front":
step_k = torch.cat([joint_tensor_key, step_k], dim=1)
step_v = torch.cat([joint_tensor_value, step_v], dim=1)
else:
step_k = torch.cat([step_k, joint_tensor_key], dim=1)
step_v = torch.cat([step_v, joint_tensor_value], dim=1)
block_out, block_lse = pytorch_attn_forward(
q,
step_k,
step_v,
softmax_scale=sm_scale,
causal=is_causal and step == 0,
op_type=op_type,
)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != comm.world_size:
comm.wait()
k = next_k
v = next_v
out = out.to(q.dtype)
return out
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm_omni.diffusion.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
logger = init_logger(__name__)
try:
from sageattention import sageattn
except ImportError:
logger.warning(
"SageAttentionBackend is not available. You may install sage-attention"
" by pip install git+https://github.com/thu-ml/SageAttention.git"
)
raise ImportError
# TODO add sage3 attention backend
class SageAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "SAGE_ATTN"
@staticmethod
def get_impl_cls() -> type["SageAttentionImpl"]:
return SageAttentionImpl
class SageAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float,
causal: bool = False,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata = None,
) -> torch.Tensor:
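# tensor_layout="NHD" means (batch, seq_len, num_heads, head_dim), matching the layout
# used by the other backends in this package.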
output = sageattn(
query,
key,
value,
tensor_layout="NHD",
is_causal=self.causal,
sm_scale=self.softmax_scale,
)
return output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm_omni.diffusion.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
logger = init_logger(__name__)
def _maybe_reshape_attn_mask(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
"""
Reshape Attention Mask
[batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
"""
# Skip the attention mask if all values are nonzero; passing `None` speeds up the computation
if attn_mask is not None and torch.all(attn_mask != 0):
attn_mask = None
# Reshape Attention Mask
# [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
attn_mask = attn_mask.to(torch.bool)
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
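# Resulting mask shape: (B, 1, Sq, Skv), broadcastable across attention heads.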
return attn_mask
class SDPABackend(AttentionBackend):
accept_output_buffer: bool = True
@classmethod
def supports_attention_mask(cls) -> bool:
return True
@staticmethod
def get_supported_head_sizes() -> list[int]:
return list(range(1024))  # TODO: restrict to the head sizes actually supported
@staticmethod
def get_name() -> str:
return "SDPA"
@staticmethod
def get_impl_cls() -> type["SDPAImpl"]:
return SDPAImpl
class SDPAImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float,
causal: bool = False,
num_kv_heads: int | None = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
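# scaled_dot_product_attention expects (batch, heads, seq, head_dim); inputs here are
# (batch, seq, heads, head_dim), so permute before the call and back afterwards.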
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
attention_mask = attn_metadata.attn_mask if attn_metadata else None
output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=self.causal,
scale=self.softmax_scale,
)
out = output.permute(0, 2, 1, 3)
return out
def forward_xpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
return self.forward_cuda(query, key, value, attn_metadata)
def forward_hip(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
return self.forward_cuda(query, key, value, attn_metadata)
def forward_npu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None = None,
) -> torch.Tensor:
if attn_metadata:
attention_mask = _maybe_reshape_attn_mask(query, key, attn_metadata.attn_mask)
attn_metadata.attn_mask = attention_mask
return self.forward_cuda(query, key, value, attn_metadata)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utils for attention backends.
"""
from vllm_omni.diffusion.attention.backends.utils.fa import _pad_input, _unpad_input, _upad_input
__all__ = [
"_pad_input",
"_unpad_input",
"_upad_input",
]
# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py
import torch
import torch.nn.functional as F
from vllm_omni.platforms import current_omni_platform
# Flash Attention function detection with fallback chain
flash_attn_func = None
flash_attn_varlen_func = None
if current_omni_platform.is_rocm():
# ROCm: try Aiter first
try:
from vllm._aiter_ops import is_aiter_found_and_supported
if is_aiter_found_and_supported():
from aiter import flash_attn_func, flash_attn_varlen_func # noqa: F401
except (ImportError, ModuleNotFoundError):
pass
else:
# CUDA: try FA3 -> FA2 fallback chain
# Try FA3 from fa3-fwd PyPI package
try:
from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401
except (ImportError, ModuleNotFoundError):
pass
# Fallback: Try FA3 from flash-attention source build
if flash_attn_func is None:
try:
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401
except (ImportError, ModuleNotFoundError):
pass
# Fallback: Try FA2 from flash-attn package (try multiple import paths)
if flash_attn_func is None:
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func # noqa: F401
except (ImportError, ModuleNotFoundError):
pass
if flash_attn_func is None:
try:
from flash_attn.flash_attn_interface import ( # noqa: F401
flash_attn_func,
flash_attn_varlen_func,
)
except (ImportError, ModuleNotFoundError):
pass
# If no FA backend available, SDPA backend will be selected at the platform level
# flash_attn_func and flash_attn_varlen_func will be None
HAS_FLASH_ATTN = flash_attn_func is not None
def _index_first_axis(tensor, indices):
"""
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
after flattening the first two dimensions of the tensor. This is functionally equivalent to
FA2's `index_first_axis` and replaces the need to import it.
"""
# The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
# two dimensions to get (total_tokens, ...) before indexing.
reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
return reshaped_tensor[indices]
def _unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
unpad_input helper for Flash Attention variants that do not ship one themselves, e.g. FA3.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
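# Illustrative: for per-sample valid lengths [3, 5], cu_seqlens = [0, 3, 8] and
# max_seqlen_in_batch = 5.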
return (
_index_first_axis(hidden_states, indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)
def _pad_input(hidden_states, indices, batch, seqlen):
"""
pad_input helper for Flash Attention variants that do not ship one themselves, e.g. FA3.
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[1:]
output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
output[indices] = hidden_states
return output.view(batch, seqlen, *dim)
def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
Arguments:
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
Return:
indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
cu_seqlens (`torch.Tensor`):
The cumulative sequence lengths, used to index into ragged (unpadded) tensors.
`cu_seqlens` shape is (batch_size + 1,).
max_seqlen_in_batch (`int`):
Maximum sequence length in batch.
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
# this might cause a graph break
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def _upad_input(
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
unpad_input_func,
):
"""
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong
to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in
order to avoid the recomputation of the same intermediary tensors for query, key, value tensors.
Arguments:
query_layer (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key_layer (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value_layer (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Target length.
unpad_input_func:
The function to use for unpadding the input tensors.
Return:
query_layer (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key_layer (`torch.Tensor`):
Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value_layer (`torch.Tensor`):
Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into
ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
`max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
if torch.compiler.is_compiling():
# allow the PyTorch compiler to capture operations that return scalar values (like .item())
torch._dynamo.config.capture_scalar_outputs = True
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
# With static caches, the k/v states may be larger than the mask ->
# we need to slice them to avoid generating garbage
# It's a bit of an anti-pattern, but otherwise we silently compute wrong attention scores
if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = _index_first_axis(key_layer, indices_k)
value_layer = _index_first_axis(value_layer, indices_k)
if query_length == kv_seq_len:
query_layer = _index_first_axis(query_layer, indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
def _is_packed_sequence(position_ids, batch_size):
"""
Check whether the position ids indicate packed sequences:
1. Position ids exist.
2. Only flattened sequences (batch_size == 1) are supported.
3. Compile-friendly equivalent of `not (torch.diff(position_ids, dim=-1) >= 0).all()`,
i.e. the batch contains multiple increasing sub-sequences.
"""
if position_ids is None:
return False
increasing_position_sequences = torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation and Jiarui Fang
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team & Jiarui Fang
# Adapted from
# https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend
from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy
from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention
from vllm_omni.diffusion.attention.selector import get_attn_backend
from vllm_omni.diffusion.distributed.parallel_state import get_sp_group
from vllm_omni.diffusion.forward_context import get_forward_context
logger = init_logger(__name__)
class Attention(nn.Module):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: int | None = None,
prefix: str = "",
# ulysses attention
scatter_idx: int = 2,
gather_idx: int = 1,
use_sync: bool = False,
):
super().__init__()
self.attn_backend = get_attn_backend(-1)
self.attn_impl_cls = self.attn_backend.get_impl_cls()
self.attention = self.attn_impl_cls(
num_heads=num_heads,
head_size=head_size,
softmax_scale=softmax_scale,
causal=causal,
num_kv_heads=num_kv_heads,
)
# Instantiate fallback backend for float32 support
self.sdpa_fallback = SDPABackend.get_impl_cls()(
num_heads=num_heads,
head_size=head_size,
softmax_scale=softmax_scale,
causal=causal,
num_kv_heads=num_kv_heads,
)
self.backend_pref = None
self.softmax_scale = softmax_scale
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
self.use_sync = use_sync
self.causal = causal
self.use_ring = False
self.ring_pg = None
self.ring_runner = None
try:
config = get_forward_context().omni_diffusion_config
self.backend_pref = config.attention_backend
if config.parallel_config.ring_degree > 1:
self.use_ring = True
try:
sp_group = get_sp_group()
self.ring_pg = sp_group.ring_group
self.ring_runner = RingParallelAttention(sp_group)
except Exception:
self.use_ring = False
self.ring_runner = None
except Exception:
self.use_ring = False
self.ring_runner = None
self.parallel_strategy = build_parallel_attention_strategy(
scatter_idx=scatter_idx,
gather_idx=gather_idx,
use_sync=use_sync,
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata = None,
) -> torch.Tensor:
# 1. Prepare inputs (Communication / Resharding)
# For Ulysses: AllToAll Q/K/V; Slicing joint_q/k/v
# For Ring: Concat joint_q
query, key, value, attn_metadata, ctx = self.parallel_strategy.pre_attention(query, key, value, attn_metadata)
# 2. Kernel Execution (Computation)
if self.use_ring:
out = self._run_ring_attention(query, key, value, attn_metadata)
else:
out = self._run_local_attention(query, key, value, attn_metadata)
# 3. Post-processing (Reverse Communication)
# For Ulysses: AllToAll Output, and AllGather Joint Output
out = self.parallel_strategy.post_attention(out, ctx)
return out
def _run_local_attention(self, query, key, value, attn_metadata):
if query.dtype == torch.float32:
logger.warning_once(
f"Only SDPA supports float32. Overriding user config {type(self.attention)} "
f"attention_backend='{self.backend_pref}' to 'sdpa' for dtype={query.dtype}."
)
return self.sdpa_fallback.forward(query, key, value, attn_metadata)
# Fallback to standard attention
return self.attention.forward(query, key, value, attn_metadata)
def _run_ring_attention(self, query, key, value, attn_metadata):
# Delegate to RingParallelAttention strategy if available
if self.ring_runner is not None:
return self.ring_runner.run_attention(
query, key, value, attn_metadata, softmax_scale=self.softmax_scale, causal=self.causal
)
raise RuntimeError("Ring attention is enabled but strategy is not RingParallelAttention")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Parallel attention strategies.
This package provides **communication / resharding strategies** for attention,
orthogonal to the **attention kernel backend** (SDPA/Flash/Sage).
The goal is to keep `vllm_omni.diffusion.attention.layer.Attention` small and
extensible: adding a new parallelism method should not require editing the core
Attention module, only adding a new strategy and selecting it in the factory.
"""
from .base import NoParallelAttention, ParallelAttentionContext, ParallelAttentionStrategy
from .factory import build_parallel_attention_strategy
__all__ = [
"ParallelAttentionStrategy",
"ParallelAttentionContext",
"NoParallelAttention",
"build_parallel_attention_strategy",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
import torch
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
@dataclass(frozen=True, slots=True)
class ParallelAttentionContext:
"""Opaque per-forward context returned by a parallel strategy.
Strategies may stash whatever they need here to finish post-processing after
the attention kernel runs (e.g. reverse resharding, slicing metadata, etc.).
"""
name: str
class ParallelAttentionStrategy(Protocol):
"""Pluggable strategy for parallel attention communication/resharding.
This is intentionally orthogonal to the attention *kernel* backend.
The kernel backend implements `AttentionImpl.forward()` for a given device,
while the parallel strategy implements how Q/K/V and outputs are sharded /
communicated across ranks.
"""
@property
def enabled(self) -> bool: ...
@property
def name(self) -> str: ...
def pre_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, AttentionMetadata | None, ParallelAttentionContext | None]:
"""Runs before the attention kernel.
Returns possibly transformed Q/K/V and metadata, and an optional context
for `post_attention`.
"""
def post_attention(
self,
attn_output: torch.Tensor,
ctx: ParallelAttentionContext | None,
) -> torch.Tensor:
"""Runs after the attention kernel."""
class NoParallelAttention:
"""Default strategy: do nothing (single device / no SP)."""
@property
def enabled(self) -> bool:
return False
@property
def name(self) -> str:
return "none"
def pre_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None,
):
return query, key, value, attn_metadata, None
def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor:
return attn_output
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from vllm.logger import init_logger
from vllm_omni.diffusion.attention.parallel.base import NoParallelAttention, ParallelAttentionStrategy
from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention
from vllm_omni.diffusion.attention.parallel.ulysses import UlyssesParallelAttention
from vllm_omni.diffusion.distributed.parallel_state import get_sequence_parallel_world_size, get_sp_group
from vllm_omni.diffusion.forward_context import get_forward_context
logger = init_logger(__name__)
def build_parallel_attention_strategy(
*,
scatter_idx: int,
gather_idx: int,
use_sync: bool,
) -> ParallelAttentionStrategy:
"""Select a parallel attention strategy based on current diffusion config.
Design principle:
- Attention kernel backend selection remains in `attention/selector.py`.
- Parallel attention selection is handled here, based on distributed config
and initialized process groups.
"""
try:
cfg = get_forward_context().omni_diffusion_config
p = cfg.parallel_config
except Exception as e:
logger.debug(f"No forward context available for parallel attention strategy: {e}")
return NoParallelAttention()
ulysses_degree = getattr(p, "ulysses_degree", 1)
ring_degree = getattr(p, "ring_degree", 1)
try:
sp_group = get_sp_group()
# Ensure SP group is initialized and world size > 1
if get_sequence_parallel_world_size() <= 1:
return NoParallelAttention()
except Exception as e:
# Log warning if SP is configured but group is not available
if ulysses_degree > 1 or ring_degree > 1:
logger.warning(
f"SP configured (ulysses={ulysses_degree}, ring={ring_degree}) but SP group not available: {e}. "
f"Falling back to NoParallelAttention. This may cause incorrect results."
)
return NoParallelAttention()
# Ulysses (or Hybrid Ulysses+Ring)
if ulysses_degree > 1:
logger.debug(f"Using UlyssesParallelAttention (ulysses_degree={ulysses_degree})")
return UlyssesParallelAttention(
sp_group=sp_group,
scatter_idx=scatter_idx,
gather_idx=gather_idx,
use_sync=use_sync,
)
# Pure Ring Attention
if ring_degree > 1:
logger.debug(f"Using RingParallelAttention (ring_degree={ring_degree})")
return RingParallelAttention(
sp_group=sp_group,
)
return NoParallelAttention()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
# import torch.distributed as dist # Not used directly here, but good practice if needed
from vllm_omni.diffusion.attention.backends.ring.ring_globals import HAS_FA3, HAS_FLASH_ATTN
from vllm_omni.diffusion.attention.backends.ring.ring_selector import AttnType
from vllm_omni.diffusion.attention.parallel.base import (
ParallelAttentionContext,
# ParallelAttentionStrategy, # Not used in type hint below currently
)
from vllm_omni.diffusion.distributed.group_coordinator import SequenceParallelGroupCoordinator
# from vllm_omni.diffusion.attention.backends.ring_selector import AttnType # Already imported above
from vllm_omni.diffusion.forward_context import get_forward_context
if TYPE_CHECKING:
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
@dataclass(frozen=True, slots=True)
class _RingCtx(ParallelAttentionContext):
"""Per-forward context for Ring sequence-parallel attention."""
# Ring attention typically doesn't need complex context for post-processing
# as the output is already correctly sharded along sequence dimension.
pass
class RingParallelAttention:
"""Ring sequence-parallel strategy.
This strategy prepares inputs for Ring Attention.
Key responsibilities:
- Concatenate joint_query (Text) to query (Image) if present.
- Keep joint_key/value separate in metadata for the Ring kernel to handle as static prefix.
"""
def __init__(
self,
sp_group: SequenceParallelGroupCoordinator,
attn_backend_pref: str | None = None,
) -> None:
self._sp_group = sp_group
self.attn_backend_pref = attn_backend_pref
@property
def enabled(self) -> bool:
return True
@property
def name(self) -> str:
return "ring"
def pre_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None,
):
joint_tensor_query = None
joint_strategy = "front"
if attn_metadata is not None:
joint_tensor_query = attn_metadata.joint_query
joint_strategy = attn_metadata.joint_strategy
if joint_tensor_query is not None:
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(f"joint_strategy: {joint_strategy} not supported.")
if joint_strategy == "front":
query = torch.cat([joint_tensor_query, query], dim=1)
else:
query = torch.cat([query, joint_tensor_query], dim=1)
# Note: We do NOT concatenate joint_key/value here.
# They are preserved in attn_metadata and will be passed
# explicitly to ring_flash_attn_func.
ctx = _RingCtx(name=self.name)
return query, key, value, attn_metadata, ctx
def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor:
# Ring attention output is already sharded correctly along sequence dimension.
return attn_output
def run_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None,
softmax_scale: float | None = None,
causal: bool = False,
) -> torch.Tensor:
"""Run the actual Ring Attention kernel."""
if softmax_scale is None:
softmax_scale = query.shape[-1] ** -0.5
backend_pref = self.attn_backend_pref
if backend_pref is None:
try:
config = get_forward_context().omni_diffusion_config
# config might not have attention_backend attribute if not updated
backend_pref = getattr(config, "attention_backend", None)
except Exception:
backend_pref = None
# Determine attention type with fallback chain: FA3 -> FA2 -> SDPA
# FP32 is not supported by Flash Attention, force SDPA
if query.dtype == torch.float32:
backend_pref = "sdpa"
elif not HAS_FA3 and not HAS_FLASH_ATTN:
if backend_pref != "sdpa":
logger = init_logger(__name__)
logger.warning_once("Flash Attention (FA2/FA3) is not available! Force enabling SDPA.")
backend_pref = "sdpa"
# Extract joint tensors
joint_key, joint_value = None, None
joint_strategy = "front"
if attn_metadata is not None:
joint_key = attn_metadata.joint_key
joint_value = attn_metadata.joint_value
if attn_metadata.joint_strategy is not None:
joint_strategy = attn_metadata.joint_strategy
if backend_pref == "sdpa" or backend_pref == "torch":
from vllm_omni.diffusion.attention.backends.ring_pytorch_attn import ring_pytorch_attn_func
return ring_pytorch_attn_func(
query,
key,
value,
softmax_scale=softmax_scale,
causal=causal,
group=self._sp_group.ring_group,
op_type="efficient",
joint_tensor_key=joint_key,
joint_tensor_value=joint_value,
joint_strategy=joint_strategy,
)
from vllm_omni.diffusion.attention.backends.ring_flash_attn import ring_flash_attn_func
# Prefer FA3 over FA2 for better performance (FA3 supports Ampere/Ada/Hopper)
attn_type = AttnType.FA3 if HAS_FA3 else AttnType.FA
return ring_flash_attn_func(
query,
key,
value,
dropout_p=0.0,
softmax_scale=softmax_scale,
causal=causal,
window_size=(-1, -1),
softcap=0.0,
alibi_slopes=None,
deterministic=False,
group=self._sp_group.ring_group,
attn_type=attn_type,
joint_tensor_key=joint_key,
joint_tensor_value=joint_value,
joint_strategy=joint_strategy,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.distributed as dist
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.parallel.base import ParallelAttentionContext
from vllm_omni.diffusion.distributed.comm import SeqAllToAll4D
from vllm_omni.diffusion.distributed.group_coordinator import SequenceParallelGroupCoordinator
@dataclass(frozen=True, slots=True)
class _UlyssesCtx(ParallelAttentionContext):
"""Per-forward context for Ulysses sequence-parallel attention."""
ulysses_pg: dist.ProcessGroup
scatter_idx: int
gather_idx: int
use_sync: bool
joint_len: int = 0
joint_strategy: str = "front"
class UlyssesParallelAttention:
"""Ulysses sequence-parallel strategy (all-to-all over seq/head dims).
This preserves the semantics previously implemented in
`Attention._forward_ulysses`:
- If `AttentionMetadata.joint_*` is provided, joint_query/key/value are
concatenated *after* all-to-all.
- joint_key/value are assumed to be replicated across SP ranks and are sliced
by ulysses head rank before concatenation.
"""
def __init__(
self,
sp_group: SequenceParallelGroupCoordinator,
scatter_idx: int,
gather_idx: int,
use_sync: bool,
) -> None:
self._sp_group = sp_group
self._ulysses_pg = sp_group.ulysses_group
self._scatter_idx = scatter_idx
self._gather_idx = gather_idx
self._use_sync = use_sync
@property
def enabled(self) -> bool:
return True
@property
def name(self) -> str:
return "ulysses"
def pre_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata | None,
):
joint_tensor_query = joint_tensor_key = joint_tensor_value = None
joint_strategy = "front"
joint_len = 0
if attn_metadata is not None:
joint_tensor_query = attn_metadata.joint_query
joint_tensor_key = attn_metadata.joint_key
joint_tensor_value = attn_metadata.joint_value
joint_strategy = attn_metadata.joint_strategy
is_joint = False
if joint_tensor_query is not None and joint_tensor_key is not None and joint_tensor_value is not None:
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supported."
f" supported joint strategy: {supported_joint_strategy}"
)
# Slice joint_query for this Ulysses rank
# joint_query is (B, S, H, D). We split H (dim 2).
ulysses_world_size = self._sp_group.ulysses_world_size
ulysses_rank = self._sp_group.ulysses_rank
attn_heads_per_ulysses_rank = joint_tensor_query.shape[-2] // ulysses_world_size
# Note: We use the same heads for Q/K/V
joint_tensor_query = joint_tensor_query[
...,
attn_heads_per_ulysses_rank * ulysses_rank : attn_heads_per_ulysses_rank * (ulysses_rank + 1),
:,
]
joint_len = joint_tensor_query.shape[1]
is_joint = True
elif joint_tensor_query is None and joint_tensor_key is None and joint_tensor_value is None:
pass
else:
raise ValueError("joint_query, joint_key, and joint_value should be None or not None simultaneously.")
if is_joint:
# Slice joint key/value heads for this ulysses rank.
# Using same slicing logic as query
attn_heads_per_ulysses_rank_kv = joint_tensor_key.shape[-2] // ulysses_world_size
joint_tensor_key = joint_tensor_key[
...,
attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1),
:,
]
joint_tensor_value = joint_tensor_value[
...,
attn_heads_per_ulysses_rank_kv * ulysses_rank : attn_heads_per_ulysses_rank_kv * (ulysses_rank + 1),
:,
]
# Update metadata with sliced tensors so Ring attention can use them if needed
if attn_metadata is not None:
attn_metadata.joint_key = joint_tensor_key
attn_metadata.joint_value = joint_tensor_value
# (bs, seq_len/P, head_cnt, head_size) -> (bs, seq_len, head_cnt/P, head_size)
query = SeqAllToAll4D.apply(self._ulysses_pg, query, self._scatter_idx, self._gather_idx, self._use_sync)
key = SeqAllToAll4D.apply(self._ulysses_pg, key, self._scatter_idx, self._gather_idx, self._use_sync)
value = SeqAllToAll4D.apply(self._ulysses_pg, value, self._scatter_idx, self._gather_idx, self._use_sync)
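# Illustrative shape flow for ulysses_world_size P = 4, B = 1, full S = 4096, H = 24,
# D = 128: each rank enters the all-to-all with (1, 1024, 24, 128) and leaves with
# (1, 4096, 6, 128), i.e. the full sequence but only a shard of the heads.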
if is_joint:
# Concatenate joint query AFTER AllToAll
# Image query is now (B, S, H/P, D). Joint query is (B, S_txt, H/P, D).
# This is dimensionally consistent.
if joint_strategy == "rear":
query = torch.cat([query, joint_tensor_query], dim=1)
else:
query = torch.cat([joint_tensor_query, query], dim=1)
# Check if Ring Attention is also active (Hybrid mode)
# If Ring is active, we should NOT concatenate joint_key/value to k/v here.
# Instead, they should remain in attn_metadata and be passed to the Ring kernel.
use_ring = self._sp_group.ring_world_size > 1
if is_joint and not use_ring:
# Concatenate joint key/value after all-to-all ONLY for pure Ulysses (Local Attention).
if joint_strategy == "front":
key = torch.cat([joint_tensor_key, key], dim=1)
value = torch.cat([joint_tensor_value, value], dim=1)
else: # "rear"
key = torch.cat([key, joint_tensor_key], dim=1)
value = torch.cat([value, joint_tensor_value], dim=1)
ctx = _UlyssesCtx(
name=self.name,
ulysses_pg=self._ulysses_pg,
scatter_idx=self._scatter_idx,
gather_idx=self._gather_idx,
use_sync=self._use_sync,
joint_len=joint_len,
joint_strategy=joint_strategy,
)
if attn_metadata is not None:
if is_joint:
if attn_metadata.joint_attn_mask is None and attn_metadata.attn_mask is None:
attn_metadata.attn_mask = None
else:
if attn_metadata.attn_mask is None:
attn_metadata.attn_mask = torch.ones(
[query.shape[0], query.shape[1] - attn_metadata.joint_attn_mask.shape[1]],
dtype=torch.bool,
device=query.device,
)
elif attn_metadata.joint_attn_mask is None:
attn_metadata.joint_attn_mask = torch.ones(
[query.shape[0], query.shape[1] - attn_metadata.attn_mask.shape[1]],
dtype=torch.bool,
device=query.device,
)
attn_metadata.attn_mask = (
torch.cat([attn_metadata.joint_attn_mask, attn_metadata.attn_mask], dim=1)
if joint_strategy == "front"
else torch.cat([attn_metadata.attn_mask, attn_metadata.joint_attn_mask], dim=1)
)
if attn_metadata.attn_mask is not None:
# the final attn_mask is ready; its length should be aligned with the query length
assert attn_metadata.attn_mask.shape[1] == query.shape[1], (
f"attn_mask length: {attn_metadata.attn_mask.shape[1]} != query length: {query.shape[1]}"
)
attn_metadata.attn_mask = attn_metadata.attn_mask.bool().contiguous()
return query, key, value, attn_metadata, ctx
def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor:
assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}"
# If we have joint tensors (Text), they were Head-Sliced.
# The main sequence (Image) was Sequence-Sliced.
# attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front').
if ctx.joint_len > 0:
joint_len = ctx.joint_len
if ctx.joint_strategy == "front":
output_joint = attn_output[:, :joint_len]
output_img = attn_output[:, joint_len:]
else:
output_img = attn_output[:, :-joint_len]
output_joint = attn_output[:, -joint_len:]
# 1. Process Image part: Standard Ulysses Reverse (AllToAll)
# (bs, seq_len, head_cnt/P, head_size) -> (bs, seq_len/P, head_cnt, head_size)
# SeqAllToAll4D handles: Scatter gather_idx, Gather scatter_idx.
# Forward: Scatter 2 (H), Gather 1 (S).
# Reverse: Scatter 1 (S), Gather 2 (H).
output_img = SeqAllToAll4D.apply(ctx.ulysses_pg, output_img, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync)
# 2. Process Joint part: AllGather on Heads
# Input: (B, JointLen, H/P, D). Output: (B, JointLen, H, D).
# AllGather along dim 2.
# Ensure tensor is contiguous for all_gather (slicing may create non-contiguous views)
output_joint = output_joint.contiguous()
gathered_joint = [torch.zeros_like(output_joint) for _ in range(dist.get_world_size(ctx.ulysses_pg))]
dist.all_gather(gathered_joint, output_joint, group=ctx.ulysses_pg)
output_joint = torch.cat(gathered_joint, dim=2)
# 3. Recombine
if ctx.joint_strategy == "front":
return torch.cat([output_joint, output_img], dim=1)
else:
return torch.cat([output_img, output_joint], dim=1)
# Standard Ulysses Reverse
return SeqAllToAll4D.apply(ctx.ulysses_pg, attn_output, ctx.gather_idx, ctx.scatter_idx, ctx.use_sync)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Diffusion attention backend selector.
This module provides the interface for selecting diffusion attention backends.
The actual backend selection logic is delegated to the platform layer
(vllm_omni.platforms), similar to how vLLM handles attention backend selection.
Usage:
from vllm_omni.diffusion.attention.selector import get_attn_backend
# Get the appropriate backend for current platform
backend_cls = get_attn_backend(head_size=64)
# Or override via environment variable
# export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN
"""
import importlib
import os
from functools import cache
from vllm.logger import init_logger
from vllm_omni.diffusion.attention.backends.abstract import (
AttentionBackend,
)
logger = init_logger(__name__)
def _load_backend_cls(cls_path: str) -> type[AttentionBackend]:
"""Load a backend class from its fully qualified path.
Args:
cls_path: Fully qualified class path (e.g.,
"vllm_omni.diffusion.attention.backends.sdpa.SDPABackend")
Returns:
The loaded backend class
"""
module_path, class_name = cls_path.rsplit(".", 1)
try:
module = importlib.import_module(module_path)
backend_class = getattr(module, class_name)
return backend_class
except ImportError as e:
raise ImportError(f"Failed to import module {module_path}: {e}") from e
except AttributeError as e:
raise AttributeError(f"Class {class_name} not found in module {module_path}: {e}") from e
@cache
def get_attn_backend(head_size: int) -> type[AttentionBackend]:
"""
Get attention backend for diffusion models.
The backend selection is delegated to the current platform
(vllm_omni.platforms.current_omni_platform), which selects the
appropriate backend based on:
1. User override via DIFFUSION_ATTENTION_BACKEND environment variable
2. Platform-specific defaults and capabilities
This is similar to how vLLM's get_attn_backend_cls works, where the
platform layer decides which backend to use based on hardware capabilities.
Args:
head_size: Head size for attention computation (may affect backend selection)
Returns:
The selected attention backend class
"""
from vllm_omni.platforms import current_omni_platform
# Check environment variable for user override
selected_backend = os.environ.get("DIFFUSION_ATTENTION_BACKEND")
# Delegate to platform for backend selection
backend_cls_path = current_omni_platform.get_diffusion_attn_backend_cls(
selected_backend=selected_backend,
head_size=head_size,
)
return _load_backend_cls(backend_cls_path)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Cache module for diffusion model inference acceleration.
This module provides a unified cache backend system for different caching strategies:
- TeaCache: Timestep Embedding Aware Cache for adaptive transformer caching
- cache-dit: DBCache, SCM, and TaylorSeer caching strategies
Cache backends are instantiated directly via their constructors and configured via OmniDiffusionConfig.
"""
from vllm_omni.diffusion.cache.base import CacheBackend
from vllm_omni.diffusion.cache.teacache import (
CacheContext,
TeaCacheConfig,
apply_teacache_hook,
)
from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend
__all__ = [
"CacheBackend",
"TeaCacheConfig",
"CacheContext",
"TeaCacheBackend",
"apply_teacache_hook",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Base cache backend interface for diffusion models.
This module defines the abstract base class that all cache backends must implement.
Cache backends provide a unified interface for applying different caching strategies
to transformer models.
Main cache backend implementations:
1. CacheDiTBackend: Implements cache-dit acceleration (DBCache, SCM, TaylorSeer) using
the cache-dit library. Inherits from CacheBackend. Used via cache_backend="cache_dit".
2. TeaCacheBackend: Hook-based backend for TeaCache acceleration. Inherits from
CacheBackend. Used via cache_backend="tea_cache".
All backends implement the same interface:
- enable(pipeline): Enable cache on the pipeline
- refresh(pipeline, num_inference_steps, verbose): Refresh cache state
- is_enabled(): Check if cache is enabled
"""
from abc import ABC, abstractmethod
from typing import Any
import torch.nn as nn
from vllm_omni.diffusion.data import DiffusionCacheConfig
class CacheBackend(ABC):
"""
Abstract base class for cache backends.
All cache backend implementations (CacheDiTBackend, TeaCacheBackend, etc.) inherit
from this base class and implement the enable() and refresh() methods to manage
cache lifecycle.
Cache backends apply caching strategies to transformer models to accelerate
inference. Different backends use different underlying mechanisms (e.g., cache-dit
library for CacheDiTBackend, hooks for TeaCacheBackend), but all share the same
unified interface.
Attributes:
config: DiffusionCacheConfig instance containing cache-specific configuration parameters
enabled: Boolean flag indicating whether cache is enabled (set to True after enable() is called)
"""
def __init__(self, config: DiffusionCacheConfig):
"""
Initialize cache backend with configuration.
Args:
config: DiffusionCacheConfig instance with cache-specific parameters
"""
self.config = config
self.enabled = False
@abstractmethod
def enable(self, pipeline: Any) -> None:
"""
Enable cache on the pipeline.
This method applies the caching strategy to the transformer(s) in the pipeline.
The specific implementation depends on the backend (e.g., hooks for TeaCacheBackend,
cache-dit library for CacheDiTBackend). Called once during pipeline initialization.
Args:
pipeline: Diffusion pipeline instance. The backend can extract:
- transformer: via pipeline.transformer
- model_type: via pipeline.__class__.__name__
"""
raise NotImplementedError("Subclasses must implement enable()")
@abstractmethod
def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""
Refresh cache state for new generation.
This method should clear any cached values and reset counters/accumulators.
Called at the start of each generation to ensure clean state.
Args:
pipeline: Diffusion pipeline instance. The backend can extract:
- transformer: via pipeline.transformer
num_inference_steps: Number of inference steps for the current generation.
May be used for cache context updates.
verbose: Whether to log refresh operations (default: True)
"""
raise NotImplementedError("Subclasses must implement refresh()")
def is_enabled(self) -> bool:
"""
Check if cache is enabled on this backend.
Returns:
True if cache is enabled, False otherwise.
"""
return self.enabled
def __repr__(self) -> str:
return f"{self.__class__.__name__}(config={self.config})"
class CachedTransformer(nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.do_true_cfg = False
def __init_subclass__(cls, enable_separate_cfg: bool = True, **kwargs):
cls.enable_separate_cfg = enable_separate_cfg
super().__init_subclass__(**kwargs)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
cache-dit integration backend for vllm-omni.
This module provides a CacheDiTBackend class to enable cache-dit acceleration on diffusion
pipelines in vllm-omni, supporting both single and dual-transformer architectures.
"""
import functools
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any, Optional
import cache_dit
import torch
from cache_dit import BlockAdapter, DBCacheConfig, ForwardPattern, ParamsModifier, TaylorSeerCalibratorConfig
from cache_dit.caching.block_adapters import FakeDiffusionPipeline
from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter
from cache_dit.caching.cache_blocks.pattern_0_1_2 import CachedBlocks_Pattern_0_1_2
from cache_dit.caching.cache_contexts import BasicCacheConfig
from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager
from vllm.logger import init_logger
from vllm_omni.diffusion.cache.base import CacheBackend
from vllm_omni.diffusion.data import DiffusionCacheConfig, OmniDiffusionConfig
logger = init_logger(__name__)
# Small helper to centralize cache-dit summaries.
def cache_summary(pipeline: Any, details: bool = True) -> None:
cache_dit.summary(pipeline.transformer, details=details)
if hasattr(pipeline, "transformer_2"):
cache_dit.summary(pipeline.transformer_2, details=details)
# Registry of custom cache-dit enablers for specific models
# Maps pipeline names to their cache-dit enablement functions
# Models in this registry require custom handling (e.g., dual-transformer architectures)
# Will be populated after function definitions
CUSTOM_DIT_ENABLERS: dict[str, Callable] = {}
def _build_db_cache_config(cache_config: Any) -> DBCacheConfig:
"""Build DBCacheConfig with optional SCM (Step Computation Masking) support.
Args:
cache_config: DiffusionCacheConfig instance.
Returns:
DBCacheConfig instance with SCM support if configured.
"""
return DBCacheConfig(
# The cache context is refreshed once num_inference_steps arrives with the first inference request.
num_inference_steps=None,
Fn_compute_blocks=cache_config.Fn_compute_blocks,
Bn_compute_blocks=cache_config.Bn_compute_blocks,
max_warmup_steps=cache_config.max_warmup_steps,
max_cached_steps=cache_config.max_cached_steps,
max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
residual_diff_threshold=cache_config.residual_diff_threshold,
)
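# Example mapping (the attribute names are the DiffusionCacheConfig fields read above;
# constructing the config with keyword arguments is an assumption for illustration):
#
#     cfg = DiffusionCacheConfig(Fn_compute_blocks=8, Bn_compute_blocks=0, max_warmup_steps=4)
#     db_cfg = _build_db_cache_config(cfg)
#     # db_cfg.num_inference_steps stays None until the first request supplies it via refresh.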
def enable_cache_for_wan22(pipeline: Any, cache_config: Any) -> Callable[[Any, int, bool], None]:
"""Enable cache-dit for Wan2.2 dual-transformer architecture.
Wan2.2 uses two transformers (transformer and transformer_2) that need
to be enabled together using BlockAdapter.
Args:
pipeline: The Wan2.2 pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
Returns:
A refresh function that can be called to update cache context with new num_inference_steps.
"""
cache_dit.enable_cache(
BlockAdapter(
transformer=[
pipeline.transformer,
pipeline.transformer_2,
],
blocks=[
pipeline.transformer.blocks,
pipeline.transformer_2.blocks,
],
forward_pattern=[
ForwardPattern.Pattern_2,
ForwardPattern.Pattern_2,
],
params_modifiers=[
# The high-noise transformer typically runs only ~30% of the steps.
ParamsModifier(
cache_config=DBCacheConfig().reset(
max_warmup_steps=cache_config.max_warmup_steps,
max_cached_steps=cache_config.max_cached_steps,
),
),
ParamsModifier(
cache_config=DBCacheConfig().reset(
max_warmup_steps=2,
max_cached_steps=20,
),
),
],
has_separate_cfg=True,
),
cache_config=DBCacheConfig(
Fn_compute_blocks=cache_config.Fn_compute_blocks,
Bn_compute_blocks=cache_config.Bn_compute_blocks,
max_warmup_steps=cache_config.max_warmup_steps,
max_cached_steps=cache_config.max_cached_steps,
max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
residual_diff_threshold=cache_config.residual_diff_threshold,
num_inference_steps=None,
),
)
# from https://github.com/vipshop/cache-dit/pull/542
def _split_inference_steps(num_inference_steps: int) -> tuple[int, int]:
"""Split inference steps into high-noise and low-noise steps for Wan2.2.
This is an internal helper function specific to Wan2.2's dual-transformer
architecture that uses boundary_ratio to determine the split point.
Args:
num_inference_steps: Total number of inference steps.
Returns:
A tuple of (num_high_noise_steps, num_low_noise_steps).
"""
if pipeline.boundary_ratio is not None:
boundary_timestep = pipeline.boundary_ratio * pipeline.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
# Set timesteps to calculate the split
device = next(pipeline.transformer.parameters()).device
pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipeline.scheduler.timesteps
num_high_noise_steps = 0 # high-noise steps for transformer
for t in timesteps:
if boundary_timestep is None or t >= boundary_timestep:
num_high_noise_steps += 1
# low-noise steps for transformer_2
num_low_noise_steps = num_inference_steps - num_high_noise_steps
return num_high_noise_steps, num_low_noise_steps
def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache context for both transformers with new num_inference_steps.
Args:
pipeline: The Wan2.2 pipeline instance.
num_inference_steps: New number of inference steps.
"""
num_high_noise_steps, num_low_noise_steps = _split_inference_steps(num_inference_steps)
# Refresh context for high-noise transformer
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(
pipeline.transformer,
num_inference_steps=num_high_noise_steps,
verbose=verbose,
)
cache_dit.refresh_context(
pipeline.transformer_2,
num_inference_steps=num_low_noise_steps,
verbose=verbose,
)
else:
cache_dit.refresh_context(
pipeline.transformer,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_high_noise_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_high_noise_steps
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)
cache_dit.refresh_context(
pipeline.transformer_2,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_low_noise_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy, total_steps=num_low_noise_steps
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)
return refresh_cache_context
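# Usage sketch (assumes a loaded Wan2.2 pipeline exposing `transformer`, `transformer_2`,
# `boundary_ratio`, and a scheduler, as used above):
#
#     refresh_fn = enable_cache_for_wan22(pipeline, cache_config)
#     refresh_fn(pipeline, num_inference_steps=40)  # split into high/low-noise steps internally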
def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable[[Any, int, bool], None]:
"""Enable cache-dit for LongCatImage pipeline.
Args:
pipeline: The LongCatImage pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
"""
# Build DBCacheConfig for transformer
db_cache_config = _build_db_cache_config(cache_config)
calibrator = None
if cache_config.enable_taylorseer:
taylorseer_order = cache_config.taylorseer_order
calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
# Build ParamsModifier for transformer
modifier = ParamsModifier(
cache_config=db_cache_config,
calibrator_config=calibrator,
)
logger.info(
f"Enabling cache-dit on LongCatImage transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)
# Enable cache-dit using BlockAdapter for transformer
cache_dit.enable_cache(
(
BlockAdapter(
transformer=pipeline.transformer,
blocks=[
pipeline.transformer.transformer_blocks,
pipeline.transformer.single_transformer_blocks,
],
forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
params_modifiers=[modifier],
)
),
cache_config=db_cache_config,
)
def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache context for the transformer with new num_inference_steps.
Args:
pipeline: The LongCatImage pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
else:
cache_dit.refresh_context(
pipeline.transformer,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_inference_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy,
total_steps=num_inference_steps,
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)
return refresh_cache_context
def enable_cache_for_flux(pipeline: Any, cache_config: Any) -> Callable[[Any, int, bool], None]:
"""Enable cache-dit for Flux dual-transformer architecture.
Flux uses two transformers (transformer and transformer_2) that need
to be enabled together using BlockAdapter.
Args:
pipeline: The Flux pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
Returns:
A refresh function that can be called to update cache context with new num_inference_steps.
"""
raise NotImplementedError("cache-dit is not implemented for Flux pipeline.")
def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[Any, int, bool], None]:
"""Enable cache-dit for StableDiffusion3Pipeline.
Args:
pipeline: The StableDiffusion3 pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
"""
# Build DBCacheConfig for transformer
db_cache_config = _build_db_cache_config(cache_config)
calibrator = None
if cache_config.enable_taylorseer:
taylorseer_order = cache_config.taylorseer_order
calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
# Build ParamsModifier for transformer
modifier = ParamsModifier(
cache_config=db_cache_config,
calibrator_config=calibrator,
)
logger.info(
f"Enabling cache-dit on StableDiffusion3 transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)
# Enable cache-dit using BlockAdapter for transformer
cache_dit.enable_cache(
(
BlockAdapter(
transformer=pipeline.transformer,
blocks=pipeline.transformer.transformer_blocks,
forward_pattern=ForwardPattern.Pattern_1,
params_modifiers=[modifier],
)
),
cache_config=db_cache_config,
)
def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache context for the transformer with new num_inference_steps.
Args:
pipeline: The StableDiffusion3 pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
else:
cache_dit.refresh_context(
pipeline.transformer,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_inference_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy,
total_steps=num_inference_steps,
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)
return refresh_cache_context
def enable_cache_for_dit(pipeline: Any, cache_config: Any) -> Callable[[Any, int, bool], None]:
"""Enable cache-dit for regular single-transformer DiT models.
Args:
pipeline: The diffusion pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
Returns:
A refresh function that can be called to update cache context with new num_inference_steps.
"""
# Build DBCacheConfig with optional SCM support
db_cache_config = _build_db_cache_config(cache_config)
# Build calibrator config if TaylorSeer is enabled
calibrator_config = None
if cache_config.enable_taylorseer:
taylorseer_order = cache_config.taylorseer_order
calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
logger.info(
f"Enabling cache-dit on transformer: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)
# Enable cache-dit on the transformer
cache_dit.enable_cache(
pipeline.transformer,
cache_config=db_cache_config,
calibrator_config=calibrator_config,
)
def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache context for the transformer with new num_inference_steps.
Args:
pipeline: The diffusion pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
else:
cache_dit.refresh_context(
pipeline.transformer,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_inference_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy,
total_steps=num_inference_steps,
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)
return refresh_cache_context
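# Usage sketch for the generic single-transformer path (any pipeline exposing
# `pipeline.transformer` whose blocks cache-dit can adapt automatically):
#
#     refresh_fn = enable_cache_for_dit(pipeline, cache_config)
#     refresh_fn(pipeline, num_inference_steps=28)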
class BagelCachedContextManager(CachedContextManager):
"""
Custom CachedContextManager for Bagel that safely handles NaiveCache objects
(mapped to encoder_hidden_states) by skipping tensor operations on them.
"""
@torch.compiler.disable
def apply_cache(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
prefix: str = "Bn",
encoder_prefix: str = "Bn_encoder",
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Allow Bn and Fn prefix to be used for residual cache.
if "Bn" in prefix:
hidden_states_prev = self.get_Bn_buffer(prefix)
else:
hidden_states_prev = self.get_Fn_buffer(prefix)
assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"
if self.is_cache_residual():
hidden_states = hidden_states_prev + hidden_states
else:
# If cache is not residual, we use the hidden states directly
hidden_states = hidden_states_prev
hidden_states = hidden_states.contiguous()
if encoder_hidden_states is not None:
if "Bn" in encoder_prefix:
encoder_hidden_states_prev = self.get_Bn_encoder_buffer(encoder_prefix)
else:
encoder_hidden_states_prev = self.get_Fn_encoder_buffer(encoder_prefix)
if encoder_hidden_states_prev is not None:
if self.is_encoder_cache_residual():
# FIX: Check if encoder_hidden_states is a tensor before adding
if isinstance(encoder_hidden_states, torch.Tensor) and isinstance(
encoder_hidden_states_prev, torch.Tensor
):
encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states
else:
# If encoder cache is not residual, we use the encoder hidden states directly
encoder_hidden_states = encoder_hidden_states_prev
# FIX: Check if encoder_hidden_states is a tensor before calling contiguous
if isinstance(encoder_hidden_states, torch.Tensor):
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
class BagelCachedBlocks(CachedBlocks_Pattern_0_1_2):
"""
Custom CachedBlocks for Bagel that safely handles NaiveCache objects
by adding isinstance checks in call_Mn_blocks and compute_or_prune.
"""
def call_Mn_blocks(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
*args,
**kwargs,
):
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
for block in self._Mn_blocks():
hidden_states = block(
hidden_states,
encoder_hidden_states,
*args,
**kwargs,
)
hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states)
# compute hidden_states residual
hidden_states = hidden_states.contiguous()
hidden_states_residual = hidden_states - original_hidden_states
if (
encoder_hidden_states is not None
and original_encoder_hidden_states is not None
and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check
):
encoder_hidden_states = encoder_hidden_states.contiguous()
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
else:
encoder_hidden_states_residual = None
return (
hidden_states,
encoder_hidden_states,
hidden_states_residual,
encoder_hidden_states_residual,
)
def compute_or_prune(
self,
block_id: int, # Block index in the transformer blocks
# Below are the inputs to the block
block, # The transformer block to be executed
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
*args,
**kwargs,
):
# NOTE: Although Bagel is unlikely to use pruning, we implement a safe version just in case.
# Copied from the original implementation, with isinstance checks added.
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
can_use_prune = self._maybe_prune(
block_id,
hidden_states,
prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
)
torch._dynamo.graph_break()
if can_use_prune:
self.context_manager.add_pruned_step()
hidden_states, encoder_hidden_states = self.context_manager.apply_prune(
hidden_states,
encoder_hidden_states,
prefix=(
f"{self.cache_prefix}_{block_id}_Bn_residual"
if self.context_manager.is_cache_residual()
else f"{self.cache_prefix}_Bn_hidden_states"
),
encoder_prefix=(
f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
if self.context_manager.is_encoder_cache_residual()
else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
),
)
torch._dynamo.graph_break()
else:
# Normal steps: Compute the block and cache the residuals.
hidden_states = block(
hidden_states,
encoder_hidden_states,
*args,
**kwargs,
)
hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states)
if not self._skip_prune(block_id):
hidden_states = hidden_states.contiguous()
hidden_states_residual = hidden_states - original_hidden_states
if (
encoder_hidden_states is not None
and original_encoder_hidden_states is not None
and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check
):
encoder_hidden_states = encoder_hidden_states.contiguous()
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
else:
encoder_hidden_states_residual = None
self.context_manager.set_Fn_buffer(
original_hidden_states,
prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
)
if self.context_manager.is_cache_residual():
self.context_manager.set_Bn_buffer(
hidden_states_residual,
prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
)
else:
self.context_manager.set_Bn_buffer(
hidden_states,
prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
)
if encoder_hidden_states_residual is not None:
if self.context_manager.is_encoder_cache_residual():
self.context_manager.set_Bn_encoder_buffer(
encoder_hidden_states_residual,
prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
)
else:
self.context_manager.set_Bn_encoder_buffer(
encoder_hidden_states_residual,
prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
)
torch._dynamo.graph_break()
return hidden_states, encoder_hidden_states
class BagelCachedAdapter(CachedAdapter):
"""
Custom CachedAdapter for Bagel that uses BagelCachedContextManager and BagelCachedBlocks.
"""
@classmethod
def create_context(
cls,
block_adapter: BlockAdapter,
**context_kwargs,
) -> tuple[list[str], list[dict[str, Any]]]:
# Override to use BagelCachedContextManager
BlockAdapter.assert_normalized(block_adapter)
if BlockAdapter.is_cached(block_adapter.pipe):
return block_adapter.pipe
# Check context_kwargs
context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs)
# Each pipeline should have its own context manager instance.
cache_config: BasicCacheConfig = context_kwargs.get("cache_config", None)
assert cache_config is not None, "cache_config can not be None."
# Apply cache on pipeline: wrap cache context
pipe_cls_name = block_adapter.pipe.__class__.__name__
# USE CUSTOM CONTEXT MANAGER
context_manager = BagelCachedContextManager(
name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
persistent_context=isinstance(block_adapter.pipe, FakeDiffusionPipeline),
)
flatten_contexts, contexts_kwargs = cls.modify_context_params(block_adapter, **context_kwargs)
block_adapter.pipe._context_manager = context_manager # instance level
if not context_manager.persistent_context:
original_call = block_adapter.pipe.__class__.__call__
@functools.wraps(original_call)
def new_call(self, *args, **kwargs):
with ExitStack() as stack:
# cache context will be reset for each pipe inference
for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
stack.enter_context(
context_manager.enter_context(
context_manager.reset_context(
context_name,
**context_kwargs,
),
)
)
outputs = original_call(self, *args, **kwargs)
cls.apply_stats_hooks(block_adapter)
return outputs
block_adapter.pipe.__class__.__call__ = new_call
block_adapter.pipe.__class__._original_call = original_call
else:
# Init persistent cache context for transformer
for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
context_manager.reset_context(
context_name,
**context_kwargs,
)
block_adapter.pipe.__class__._is_cached = True
cls.apply_params_hooks(block_adapter, contexts_kwargs)
return flatten_contexts, contexts_kwargs
@classmethod
def collect_unified_blocks(
cls,
block_adapter: BlockAdapter,
contexts_kwargs: list[dict],
) -> list[dict[str, torch.nn.ModuleList]]:
# Override to use BagelCachedBlocks
BlockAdapter.assert_normalized(block_adapter)
total_cached_blocks: list[dict[str, torch.nn.ModuleList]] = []
assert hasattr(block_adapter.pipe, "_context_manager")
# Skipping isinstance check for ContextManager._supported_managers to avoid import issues
for i in range(len(block_adapter.transformer)):
unified_blocks_bind_context = {}
for j in range(len(block_adapter.blocks[i])):
cache_config: BasicCacheConfig = contexts_kwargs[i * len(block_adapter.blocks[i]) + j]["cache_config"]
# Directly instantiate BagelCachedBlocks
unified_blocks_bind_context[block_adapter.unique_blocks_name[i][j]] = torch.nn.ModuleList(
[
BagelCachedBlocks(
# 0. Transformer blocks configuration
block_adapter.blocks[i][j],
transformer=block_adapter.transformer[i],
forward_pattern=block_adapter.forward_pattern[i][j],
check_forward_pattern=block_adapter.check_forward_pattern,
check_num_outputs=block_adapter.check_num_outputs,
# 1. Cache/Prune context configuration
cache_prefix=block_adapter.blocks_name[i][j],
cache_context=block_adapter.unique_blocks_name[i][j],
context_manager=block_adapter.pipe._context_manager,
cache_type=cache_config.cache_type,
)
]
)
total_cached_blocks.append(unified_blocks_bind_context)
return total_cached_blocks
def enable_cache_for_bagel(pipeline: Any, cache_config: Any) -> Callable[[Any, int, bool], None]:
"""Enable cache-dit for Bagel model (via OmniDiffusion pipeline).
Args:
pipeline: The OmniDiffusion pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
Returns:
A refresh function that can be called to update cache context with new num_inference_steps.
"""
# Build DBCacheConfig
db_cache_config = _build_db_cache_config(cache_config)
# Build calibrator config if TaylorSeer is enabled
calibrator_config = None
if cache_config.enable_taylorseer:
taylorseer_order = cache_config.taylorseer_order
calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
# Access the transformer: BagelPipeline -> Qwen2MoTForCausalLM -> Qwen2MoTModel
# BagelPipeline has self.language_model which is Qwen2MoTForCausalLM
# Qwen2MoTForCausalLM has self.model which is Qwen2MoTModel
transformer = pipeline.language_model.model
logger.info(
f"Enabling cache-dit on Bagel transformer: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)
# Enable cache-dit on the transformer
# Pattern_0 corresponds to (hidden_states, encoder_hidden_states) input, output
# Custom adapter for Bagel to handle NaiveCache correctly
BagelCachedAdapter.apply(
BlockAdapter(
transformer=transformer,
blocks=transformer.layers,
forward_pattern=ForwardPattern.Pattern_0,
),
cache_config=db_cache_config,
calibrator_config=calibrator_config,
)
def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
transformer = pipeline.language_model.model
if cache_config.scm_steps_mask_policy is None:
cache_dit.refresh_context(transformer, num_inference_steps=num_inference_steps, verbose=verbose)
else:
cache_dit.refresh_context(
transformer,
cache_config=DBCacheConfig().reset(
num_inference_steps=num_inference_steps,
steps_computation_mask=cache_dit.steps_mask(
mask_policy=cache_config.scm_steps_mask_policy,
total_steps=num_inference_steps,
),
steps_computation_policy=cache_config.scm_steps_policy,
),
verbose=verbose,
)
return refresh_cache_context
# Register custom cache-dit enablers after function definitions
CUSTOM_DIT_ENABLERS.update(
{
"Wan22Pipeline": enable_cache_for_wan22,
"Wan22I2VPipeline": enable_cache_for_wan22,
"Wan22TI2VPipeline": enable_cache_for_wan22,
"FluxPipeline": enable_cache_for_flux,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
"BagelPipeline": enable_cache_for_bagel,
}
)
class CacheDiTBackend(CacheBackend):
"""Backend class for cache-dit acceleration on diffusion pipelines.
This class implements cache-dit acceleration (DBCache, SCM, TaylorSeer) using
the cache-dit library. It inherits from CacheBackend and provides a unified
interface for managing cache-dit acceleration on diffusion models.
Attributes:
config: Cache configuration (DiffusionCacheConfig instance), inherited from CacheBackend.
enabled: Whether cache-dit is enabled on this pipeline, inherited from CacheBackend.
_refresh_func: Internal refresh function for updating cache context.
_last_num_inference_steps: Last num_inference_steps used for refresh optimization.
"""
def __init__(self, cache_config: Any = None):
"""Initialize the cache-dit backend.
Args:
cache_config: Cache configuration (DiffusionCacheConfig instance, dict, or None).
If None or empty, uses default DiffusionCacheConfig().
"""
# Use default config if cache_config is not provided or is empty
if cache_config is None:
config = DiffusionCacheConfig()
elif isinstance(cache_config, dict):
# Convert dict to DiffusionCacheConfig, using defaults for missing keys
config = DiffusionCacheConfig.from_dict(cache_config)
else:
config = cache_config
# Initialize base class with normalized config
super().__init__(config)
# Cache-dit specific attributes
self._refresh_func: Callable[[Any, int, bool], None] | None = None
self._last_num_inference_steps: int | None = None
def enable(self, pipeline: Any) -> None:
"""Enable cache-dit on the pipeline if configured.
This method applies cache-dit acceleration to the appropriate transformer(s)
in the pipeline. It handles both single-transformer and dual-transformer
architectures (e.g., Wan2.2).
Args:
pipeline: The diffusion pipeline instance.
"""
# Extract pipeline name from pipeline
pipeline_name = pipeline.__class__.__name__
# Check if this model has a custom cache-dit enabler
if pipeline_name in CUSTOM_DIT_ENABLERS:
logger.info(f"Using custom cache-dit enabler for model: {pipeline_name}")
self._refresh_func = CUSTOM_DIT_ENABLERS[pipeline_name](pipeline, self.config)
else:
# For regular single-transformer models
self._refresh_func = enable_cache_for_dit(pipeline, self.config)
self.enabled = True
logger.info(f"Cache-dit enabled successfully on {pipeline_name}")
def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""Refresh cache context with new num_inference_steps.
This method updates the cache context when num_inference_steps changes
during inference. For dual-transformer models (e.g., Wan2.2), it automatically
splits the steps based on boundary_ratio.
Args:
pipeline: The diffusion pipeline instance.
num_inference_steps: New number of inference steps.
verbose: Whether to log refresh operations.
"""
if not self.enabled or self._refresh_func is None:
logger.warning("Cache-dit is not enabled. Cannot refresh cache context.")
return
# Only refresh if num_inference_steps has changed
if self._last_num_inference_steps is None or num_inference_steps != self._last_num_inference_steps:
if verbose:
logger.info(f"Refreshing cache context for transformer with num_inference_steps: {num_inference_steps}")
self._refresh_func(pipeline, num_inference_steps, verbose)
self._last_num_inference_steps = num_inference_steps
def is_enabled(self) -> bool:
"""Check if cache-dit is enabled on this pipeline.
Returns:
True if cache-dit is enabled, False otherwise.
"""
return self.enabled
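# Usage sketch (the dict keys are DiffusionCacheConfig fields referenced above; the
# pipeline object is a placeholder):
#
#     backend = CacheDiTBackend({"Fn_compute_blocks": 8, "residual_diff_threshold": 0.08})
#     backend.enable(pipeline)
#     backend.refresh(pipeline, num_inference_steps=50)  # no-op if the step count is unchanged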
def may_enable_cache_dit(pipeline: Any, od_config: OmniDiffusionConfig) -> Optional["CacheDiTBackend"]:
"""Enable cache-dit on the pipeline if configured (convenience function).
This is a convenience function that creates and enables a CacheDiTBackend.
For new code, consider using CacheDiTBackend directly.
Args:
pipeline: The diffusion pipeline instance.
od_config: OmniDiffusionConfig with cache configuration.
Returns:
A CacheDiTBackend instance if cache-dit is enabled, None otherwise.
"""
if od_config.cache_backend != "cache-dit" or not od_config.cache_config:
return None
backend = CacheDiTBackend(od_config.cache_config)
backend.enable(pipeline)
return backend if backend.is_enabled() else None
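# Usage sketch (od_config is an OmniDiffusionConfig whose cache_backend is "cache-dit"):
#
#     backend = may_enable_cache_dit(pipeline, od_config)
#     if backend is not None:
#         backend.refresh(pipeline, num_inference_steps=30)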
from typing import Any
from vllm_omni.diffusion.cache.base import CacheBackend
from vllm_omni.diffusion.cache.cache_dit_backend import CacheDiTBackend
from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend
from vllm_omni.diffusion.data import DiffusionCacheConfig
def get_cache_backend(cache_backend: str | None, cache_config: Any) -> CacheBackend | None:
"""Get cache backend instance based on cache_backend string.
This is a selector function that routes to the appropriate backend implementation.
- cache_dit: Uses CacheDiTBackend with enable()/refresh() interface
- tea_cache: Uses TeaCacheBackend with enable()/refresh() interface
Args:
cache_backend: Cache backend name ("cache_dit", "tea_cache", "none", or None).
cache_config: Cache configuration (dict or DiffusionCacheConfig instance).
Returns:
Cache backend instance (CacheDiTBackend or TeaCacheBackend) if cache_backend is set,
None otherwise.
Raises:
ValueError: If cache_backend is unsupported.
"""
if cache_backend is None or cache_backend == "none":
return None
if isinstance(cache_config, dict):
cache_config = DiffusionCacheConfig.from_dict(cache_config)
if cache_backend == "cache_dit":
return CacheDiTBackend(cache_config)
elif cache_backend == "tea_cache":
return TeaCacheBackend(cache_config)
else:
raise ValueError(f"Unsupported cache backend: {cache_backend}. Supported: 'cache_dit', 'tea_cache'")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
TeaCache: Timestep Embedding Aware Cache for diffusion model acceleration.
TeaCache speeds up diffusion inference by reusing transformer block computations
when consecutive timestep embeddings are similar.
This implementation uses a hooks-based approach that requires zero changes to
model code. Model developers only need to add an extractor function to support
new models.
Usage:
from vllm_omni import Omni
omni = Omni(
model="Qwen/Qwen-Image",
cache_backend="tea_cache",
cache_config={"rel_l1_thresh": 0.2}
)
images = omni.generate("a cat")
# Alternative: Using environment variable
# export DIFFUSION_CACHE_BACKEND=tea_cache
"""
from vllm_omni.diffusion.cache.teacache.backend import TeaCacheBackend
from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig
from vllm_omni.diffusion.cache.teacache.extractors import (
CacheContext,
register_extractor,
)
from vllm_omni.diffusion.cache.teacache.hook import TeaCacheHook, apply_teacache_hook
from vllm_omni.diffusion.cache.teacache.state import TeaCacheState
__all__ = [
"TeaCacheBackend",
"TeaCacheConfig",
"TeaCacheState",
"TeaCacheHook",
"apply_teacache_hook",
"register_extractor",
"CacheContext",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
TeaCache backend implementation.
This module provides the TeaCache backend that implements the CacheBackend
interface using the hooks-based TeaCache system.
"""
from typing import Any
from vllm.logger import init_logger
from vllm_omni.diffusion.cache.base import CacheBackend
from vllm_omni.diffusion.cache.teacache.config import TeaCacheConfig
from vllm_omni.diffusion.cache.teacache.hook import TeaCacheHook, apply_teacache_hook
from vllm_omni.diffusion.data import DiffusionCacheConfig
logger = init_logger(__name__)
def enable_bagel_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None:
"""
Enable TeaCache for Bagel model.
"""
teacache_config = TeaCacheConfig(
transformer_type="Bagel",
rel_l1_thresh=config.rel_l1_thresh,
coefficients=config.coefficients,
)
transformer = pipeline.bagel
original_forward_flow = transformer._forward_flow
import types
def forward_alias(self, *args, **kwargs):
return original_forward_flow(*args, **kwargs)
transformer.forward = types.MethodType(forward_alias, transformer)
apply_teacache_hook(transformer, teacache_config)
transformer._forward_flow = transformer.forward
pipeline.transformer = transformer
logger.info(
f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, "
f"transformer_class={teacache_config.transformer_type}"
)
CUSTOM_TEACACHE_ENABLERS = {"BagelPipeline": enable_bagel_teacache}
class TeaCacheBackend(CacheBackend):
"""
TeaCache implementation using hooks.
TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique
that speeds up diffusion inference by reusing transformer block computations
when consecutive timestep embeddings are similar.
The backend applies TeaCache hooks to the transformer which intercept the
forward pass and implement the caching logic transparently.
Example:
>>> from vllm_omni.diffusion.data import DiffusionCacheConfig
>>> backend = TeaCacheBackend(DiffusionCacheConfig(rel_l1_thresh=0.2))
>>> backend.enable(pipeline)
>>> # Generate with cache enabled
>>> backend.refresh(pipeline, num_inference_steps=50) # Refresh before each generation
>>> # Access config attributes: backend.config.rel_l1_thresh
"""
def enable(self, pipeline: Any) -> None:
"""
Enable TeaCache on transformer using hooks.
This creates a TeaCacheConfig from the backend's DiffusionCacheConfig
and applies the TeaCache hook to the transformer.
Args:
pipeline: Diffusion pipeline instance. Extracts transformer and transformer_type:
- transformer: pipeline.transformer
- transformer_type: pipeline.transformer.__class__.__name__
"""
# Helper to get pipeline class name
pipeline_type = pipeline.__class__.__name__
# Check for pipeline-level custom enablers
if pipeline_type in CUSTOM_TEACACHE_ENABLERS:
logger.info(f"Using custom TeaCache enabler for model: {pipeline_type}")
CUSTOM_TEACACHE_ENABLERS[pipeline_type](pipeline, self.config)
else:
transformer = pipeline.transformer
transformer_type = transformer.__class__.__name__
# Create TeaCacheConfig from DiffusionCacheConfig with transformer_type
# Access parameters via attribute access: config.rel_l1_thresh
# rel_l1_thresh already has a default value of 0.2 in DiffusionCacheConfig
try:
teacache_config = TeaCacheConfig(
transformer_type=transformer_type,
rel_l1_thresh=self.config.rel_l1_thresh,
coefficients=self.config.coefficients,
)
except Exception as e:
logger.error(f"Failed to create TeaCacheConfig: {e}")
raise ValueError(
f"Invalid TeaCache configuration: {e}. "
f"Expected keys: rel_l1_thresh, coefficients (optional). "
f"transformer_type is automatically extracted from pipeline.transformer.__class__.__name__."
) from e
# Apply hook to transformer
apply_teacache_hook(transformer, teacache_config)
logger.info(
f"TeaCache applied with rel_l1_thresh={teacache_config.rel_l1_thresh}, "
f"transformer_class={teacache_config.transformer_type}"
)
# Mark as enabled
self.enabled = True
def refresh(self, pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
"""
Refresh TeaCache state for new generation.
Clears all cached residuals and resets counters/accumulators.
Should be called before each generation to ensure clean state.
Args:
pipeline: Diffusion pipeline instance. Extracts transformer via pipeline.transformer.
num_inference_steps: Number of inference steps for the current generation.
Currently not used by TeaCache but accepted for interface consistency.
verbose: Whether to log refresh operations (default: True)
"""
# Extract transformer from pipeline
transformer = pipeline.transformer
if hasattr(transformer, "_hook_registry"):
hook = transformer._hook_registry.get_hook(TeaCacheHook._HOOK_NAME)
if hook is not None:
transformer._hook_registry.reset_hook(TeaCacheHook._HOOK_NAME)
if verbose:
logger.debug(f"TeaCache state refreshed (num_inference_steps={num_inference_steps})")
else:
if verbose:
logger.warning("TeaCache hook not found, nothing to refresh")
else:
if verbose:
logger.warning("Transformer has no hook registry, TeaCache may not be applied")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
from typing import Any
import numpy as np
import torch
from vllm.config import LoadConfig
from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
class DataCollectionHook(ModelHook):
"""Hook to collect modulated inputs and model outputs for TeaCache coefficient estimation."""
_HOOK_NAME = "teacache_collector"
def __init__(self, transformer_type: str):
super().__init__()
self.transformer_type = transformer_type
self.extractor_fn = None
self.current_trajectory: list[tuple[np.ndarray, np.ndarray]] = []
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
self.extractor_fn = get_extractor(self.transformer_type)
return module
def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
ctx = self.extractor_fn(module, *args, **kwargs)
modulated_input_cpu = ctx.modulated_input.detach().cpu().numpy()
outputs = ctx.run_transformer_blocks()
ctx.hidden_states = outputs[0]
if len(outputs) > 1 and ctx.encoder_hidden_states is not None:
ctx.encoder_hidden_states = outputs[1]
model_output_cpu = ctx.hidden_states.detach().cpu().numpy()
self.current_trajectory.append((modulated_input_cpu, model_output_cpu))
return ctx.postprocess(ctx.hidden_states)
def start_collection(self):
self.current_trajectory = []
def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
return list(self.current_trajectory)
class BagelAdapter:
"""Adapter for Bagel model."""
@staticmethod
def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
od_config.model_class_name = "BagelPipeline"
pipeline = BagelPipeline(od_config=od_config)
loader = DiffusersPipelineLoader(LoadConfig())
loader.load_weights(pipeline)
pipeline.to(device)
return pipeline
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
return pipeline.bagel, "Bagel"
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
original_forward_flow = transformer._forward_flow
def forward_alias(self, *args, **kwargs):
return original_forward_flow(*args, **kwargs)
transformer.forward = types.MethodType(forward_alias, transformer)
registry = HookRegistry.get_or_create(transformer)
registry.register_hook(hook._HOOK_NAME, hook)
transformer._forward_flow = transformer.forward
class DefaultAdapter:
"""Default adapter for standard diffusers pipelines."""
@staticmethod
def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
return pipeline.transformer, pipeline.transformer.__class__.__name__
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry = HookRegistry.get_or_create(transformer)
registry.register_hook(hook._HOOK_NAME, hook)
_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
}
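# To support another model, register an adapter exposing the same three static methods.
# Sketch only; FluxAdapter and its loading logic are hypothetical:
#
#     class FluxAdapter(DefaultAdapter):
#         @staticmethod
#         def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
#             ...  # construct and return the pipeline for this model
#
#     _MODEL_ADAPTERS["Flux"] = FluxAdapter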
_EPSILON = 1e-6
def calculate_relative_l1(tensor_current: np.ndarray, tensor_next: np.ndarray) -> float:
"""Calculate relative L1 distance (Eq. 4 from TeaCache paper)."""
diff = np.abs(tensor_current - tensor_next).sum()
norm = np.abs(tensor_current).sum() + _EPSILON
return diff / norm
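# Worked example: with tensor_current = [1, 2] and tensor_next = [2, 2], the numerator is
# |1 - 2| + |2 - 2| = 1 and the denominator is |1| + |2| + eps = 3, giving ~0.333.
#
#     >>> calculate_relative_l1(np.array([1.0, 2.0]), np.array([2.0, 2.0]))
#     0.333...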
def estimate_teacache_coefficients(
collected_data: list[list[tuple[np.ndarray, np.ndarray]]], poly_order: int = 4
) -> list[float]:
"""Estimate polynomial coefficients for TeaCache using np.polyfit."""
input_diffs, output_diffs = [], []
for sample in collected_data:
for t in range(len(sample) - 1):
feat_in_curr, feat_out_curr = sample[t]
feat_in_next, feat_out_next = sample[t + 1]
input_diffs.append(calculate_relative_l1(feat_in_curr, feat_in_next))
output_diffs.append(calculate_relative_l1(feat_out_curr, feat_out_next))
x = np.array(input_diffs, dtype=np.float64)
y = np.array(output_diffs, dtype=np.float64)
print("Data statistics:")
print(f" Count: {len(x)}")
print(f" Input Diffs (x): min={x.min():.4e}, max={x.max():.4e}, mean={x.mean():.4e}")
print(f" Output Diffs (y): min={y.min():.4e}, max={y.max():.4e}, mean={y.mean():.4e}")
return np.polyfit(x, y, poly_order).tolist()
class TeaCacheCoefficientEstimator:
"""Model-agnostic helper class to collect data and estimate TeaCache coefficients."""
def __init__(
self,
model_path: str,
model_type: str = "Bagel",
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
# Validate model_type before attempting to load the pipeline.
if model_type not in _MODEL_ADAPTERS:
available_types = list(_MODEL_ADAPTERS.keys())
raise ValueError(
f"Unsupported model_type: '{model_type}'. "
f"Available types: {available_types}. "
f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
)
adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter)
self.pipeline = adapter.load_pipeline(model_path, device, dtype)
self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
self.hook = DataCollectionHook(self.transformer_type)
self.collected_data: list[list[tuple[np.ndarray, np.ndarray]]] = []
adapter.install_hook(self.transformer, self.hook)
def collect_from_prompt(self, prompt: str, **generate_kwargs):
self.hook.start_collection()
from vllm_omni.diffusion.request import OmniDiffusionRequest
req = OmniDiffusionRequest(
prompt=prompt,
num_inference_steps=generate_kwargs.get("num_inference_steps", 20),
seed=generate_kwargs.get("seed", 42),
)
self.pipeline.forward(req)
trajectory = self.hook.stop_collection()
if trajectory:
self.collected_data.append(trajectory)
def estimate(self, poly_order: int = 4) -> list[float]:
"""Estimate polynomial coefficients from collected data.
Args:
poly_order: Order of polynomial fit (default: 4)
Returns:
List of polynomial coefficients [a_n, a_{n-1}, ..., a_1, a_0]
Raises:
RuntimeError: If no data has been collected
"""
if not self.collected_data:
raise RuntimeError(
"No data collected for coefficient estimation. "
"Call collect_from_prompt() at least once before calling estimate()."
)
return estimate_teacache_coefficients(self.collected_data, poly_order)
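# Usage sketch (model path, prompts, and step count are placeholders):
#
#     estimator = TeaCacheCoefficientEstimator("/path/to/bagel", model_type="Bagel")
#     for prompt in ["a cat", "a dog", "a city at night"]:
#         estimator.collect_from_prompt(prompt, num_inference_steps=20)
#     coefficients = estimator.estimate(poly_order=4)
#     # Pass `coefficients` to DiffusionCacheConfig.coefficients for the TeaCache backend.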