Unverified commit 0434db9a, authored by OleehyO and committed by GitHub

[cogview4][feat] Support attention mechanism with variable-length support and batch packing (#11349)

* [cogview4] Enhance attention mechanism with variable-length support and batch packing

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent aff574fb
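The inference-path change in the diff below turns a per-token text padding mask into a pairwise attention mask by an outer product: a query/key pair is kept only if both tokens are real. A minimal standalone sketch of that construction, with illustrative tensor names and shapes (not taken verbatim from the diffusers code):

```python
import torch

# Toy shapes: 2 prompts, text_seq_length = 4, image_seq_length = 3.
# 1 marks a real text token, 0 marks padding.
text_attn_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.float32)
batch_size, text_seq_length = text_attn_mask.shape
image_seq_length = 3

# Image tokens are never padded on this path, so their entries stay 1.
mix_attn_mask = torch.ones(batch_size, text_seq_length + image_seq_length)
mix_attn_mask[:, :text_seq_length] = text_attn_mask

# Outer product: entry (i, j) is nonzero only when tokens i and j are both real.
mix = mix_attn_mask.unsqueeze(2)                      # (B, L, 1)
attn_mask_matrix = mix @ mix.transpose(1, 2)          # (B, L, L)
attention_mask = (attn_mask_matrix > 0).unsqueeze(1)  # (B, 1, L, L), broadcasts over heads

# A boolean attn_mask passed to F.scaled_dot_product_attention keeps True
# positions and masks out False ones.
```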
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -73,8 +73,9 @@ class CogView4AdaLayerNormZero(nn.Module):
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
-norm_hidden_states = self.norm(hidden_states)
-norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+dtype = hidden_states.dtype
+norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
@@ -111,8 +112,11 @@ class CogView4AdaLayerNormZero(nn.Module):
class CogView4AttnProcessor:
"""
-Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
+The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
+text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
"""
def __init__(self):
@@ -125,8 +129,10 @@ class CogView4AttnProcessor:
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
-image_rotary_emb: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +148,9 @@ class CogView4AttnProcessor:
# 2. QK normalization
if attn.norm_q is not None:
-query = attn.norm_q(query)
+query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
-key = attn.norm_k(key)
+key = attn.norm_k(key).to(dtype=dtype)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
@@ -159,13 +165,14 @@ class CogView4AttnProcessor:
# 4. Attention
if attention_mask is not None:
-text_attention_mask = attention_mask.float().to(query.device)
-actual_text_seq_length = text_attention_mask.size(1)
-new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
-new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
-new_attention_mask = new_attention_mask.unsqueeze(2)
-attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
-attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+text_attn_mask = attention_mask
+assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+text_attn_mask = text_attn_mask.float().to(query.device)
+mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+mix_attn_mask[:, :text_seq_length] = text_attn_mask
+mix_attn_mask = mix_attn_mask.unsqueeze(2)
+attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -183,9 +190,276 @@ class CogView4AttnProcessor:
return hidden_states, encoder_hidden_states
class CogView4TrainingAttnProcessor:
"""
Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
embedding on query and key vectors, but does not include spatial normalization.
This processor differs from CogView4AttnProcessor in several important ways:
1. It supports attention masking with variable sequence lengths for multi-resolution training
2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
provided
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
latent_attn_mask: Optional[torch.Tensor] = None,
text_attn_mask: Optional[torch.Tensor] = None,
batch_flag: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
attn (`Attention`):
The attention module.
hidden_states (`torch.Tensor`):
The input hidden states.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states for cross-attention.
latent_attn_mask (`torch.Tensor`, *optional*):
Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
num_latent_tokens).
text_attn_mask (`torch.Tensor`, *optional*):
Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
is used for all text tokens.
batch_flag (`torch.Tensor`, *optional*):
Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
batch1, and samples 3-4 form batch2. If None, no packing is used.
image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
The rotary embedding for the image part of the input.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
"""
# Get dimensions and device info
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
dtype = encoder_hidden_states.dtype
device = encoder_hidden_states.device
latent_hidden_states = hidden_states
# Combine text and image streams for joint processing
mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
# 1. Construct attention mask and maybe packing input
# Create default masks if not provided
if text_attn_mask is None:
text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
if latent_attn_mask is None:
latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
# Validate mask shapes and types
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
# Create combined mask for text and image tokens
mixed_attn_mask = torch.ones(
(batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
)
mixed_attn_mask[:, :text_seq_length] = text_attn_mask
mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
# Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
# Handle batch packing if enabled
if batch_flag is not None:
assert batch_flag.dim() == 1
# Determine packed batch size based on batch_flag
packing_batch_size = torch.max(batch_flag).item() + 1
# Calculate actual sequence lengths for each sample based on masks
text_seq_length = torch.sum(text_attn_mask, dim=1)
latent_seq_length = torch.sum(latent_attn_mask, dim=1)
mixed_seq_length = text_seq_length + latent_seq_length
# Calculate packed sequence lengths for each packed batch
mixed_seq_length_packed = [
torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
]
assert len(mixed_seq_length_packed) == packing_batch_size
# Pack sequences by removing padding tokens
mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
# Split the unpadded sequence into packed batches
mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
# Re-pad to create packed batches with right-side padding
mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
mixed_hidden_states_packed,
batch_first=True,
padding_value=0.0,
padding_side="right",
)
# Create attention mask for packed batches
l = mixed_hidden_states_packed_padded.shape[1]
attn_mask_matrix = torch.zeros(
(packing_batch_size, l, l),
dtype=dtype,
device=device,
)
# Fill attention mask with block diagonal matrices
# This ensures that tokens can only attend to other tokens within the same original sample
for idx, mask in enumerate(attn_mask_matrix):
seq_lengths = mixed_seq_length[batch_flag == idx]
offset = 0
for length in seq_lengths:
# Create a block of 1s for each sample in the packed batch
mask[offset : offset + length, offset : offset + length] = 1
offset += length
attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim
attention_mask = attn_mask_matrix
# Prepare hidden states for attention computation
if batch_flag is None:
# If no packing, just combine text and image tokens
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
else:
# If packing, use the packed sequence
hidden_states = mixed_hidden_states_packed_padded
# 2. QKV projections - convert hidden states to query, key, value
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 3. QK normalization - apply layer norm to queries and keys if configured
if attn.norm_q is not None:
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key).to(dtype=dtype)
# 4. Apply rotary positional embeddings to image tokens only
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if batch_flag is None:
# Apply RoPE only to image tokens (after text tokens)
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
else:
# For packed batches, need to carefully apply RoPE to appropriate tokens
assert query.shape[0] == packing_batch_size
assert key.shape[0] == packing_batch_size
assert len(image_rotary_emb) == batch_size
rope_idx = 0
for idx in range(packing_batch_size):
offset = 0
# Get text and image sequence lengths for samples in this packed batch
text_seq_length_bi = text_seq_length[batch_flag == idx]
latent_seq_length_bi = latent_seq_length[batch_flag == idx]
# Apply RoPE to each image segment in the packed sequence
for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
mlen = tlen + llen
# Apply RoPE only to image tokens (after text tokens)
query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
query[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
key[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
offset += mlen
rope_idx += 1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection - project attention output to model dimension
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
# Split the output back into text and image streams
if batch_flag is None:
# Simple split for non-packed case
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
else:
# For packed case: need to unpack, split text/image, then restore to original shapes
# First, unpad the sequence based on the packed sequence lengths
hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
hidden_states,
lengths=torch.tensor(mixed_seq_length_packed),
batch_first=True,
)
# Concatenate all unpadded sequences
hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
# Split by original sample sequence lengths
hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
assert len(hidden_states_unpack) == batch_size
# Further split each sample's sequence into text and image parts
hidden_states_unpack = [
torch.split(h, [tlen, llen])
for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
]
# Separate text and image sequences
encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
hidden_states_unpad = [h[1] for h in hidden_states_unpack]
# Update the original tensors with the processed values, respecting the attention masks
for idx in range(batch_size):
# Place unpacked text tokens back in the encoder_hidden_states tensor
encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
# Place unpacked image tokens back in the latent_hidden_states tensor
latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
# Update the output hidden states
hidden_states = latent_hidden_states
return hidden_states, encoder_hidden_states
class CogView4TransformerBlock(nn.Module):
def __init__(
-self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+self,
+dim: int = 2560,
+num_attention_heads: int = 64,
+attention_head_dim: int = 40,
+time_embed_dim: int = 512,
) -> None:
super().__init__()
@@ -213,9 +487,11 @@ class CogView4TransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
-image_rotary_emb: Optional[torch.Tensor] = None,
-attention_mask: Optional[torch.Tensor] = None,
-**kwargs,
+image_rotary_emb: Optional[
+    Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+] = None,
+attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# 1. Timestep conditioning
(
@@ -232,12 +508,14 @@ class CogView4TransformerBlock(nn.Module):
) = self.norm1(hidden_states, encoder_hidden_states, temb)
# 2. Attention
+if attention_kwargs is None:
+attention_kwargs = {}
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
-**kwargs,
+**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -402,7 +680,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
-**kwargs,
+image_rotary_emb: Optional[
+    Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+] = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
@@ -422,6 +702,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
+if image_rotary_emb is None:
image_rotary_emb = self.rope(hidden_states)
# 2. Patch & Timestep embeddings
@@ -438,11 +719,22 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+block,
+hidden_states,
+encoder_hidden_states,
+temb,
+image_rotary_emb,
+attention_mask,
+attention_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
-hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+hidden_states,
+encoder_hidden_states,
+temb,
+image_rotary_emb,
+attention_mask,
+attention_kwargs,
)
# 4. Output norm & projection
...
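The new CogView4TrainingAttnProcessor above packs samples that share a batch_flag value into a single row (after dropping their pad tokens) and restricts attention with a block-diagonal mask, so tokens never attend across the original samples. A minimal sketch of that packing and mask construction under toy shapes; the variable names are illustrative, not the diffusers API:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three samples with (unpadded) token counts 5, 3 and 4; hidden size 8 is arbitrary.
# batch_flag [0, 1, 1] packs sample 0 alone and samples 1-2 into one row.
seq_lengths = torch.tensor([5, 3, 4])
batch_flag = torch.tensor([0, 1, 1])
samples = [torch.randn(int(n), 8) for n in seq_lengths]

packing_batch_size = int(batch_flag.max()) + 1
packed = [
    torch.cat([s for s, f in zip(samples, batch_flag) if f == b], dim=0)
    for b in range(packing_batch_size)
]
packed_padded = pad_sequence(packed, batch_first=True)  # right-pads to the longest pack

# Block-diagonal boolean mask: tokens may only attend within their original sample.
max_len = packed_padded.shape[1]
attn_mask = torch.zeros(packing_batch_size, max_len, max_len, dtype=torch.bool)
for b in range(packing_batch_size):
    offset = 0
    for n in seq_lengths[batch_flag == b].tolist():
        attn_mask[b, offset:offset + n, offset:offset + n] = True
        offset += n
attn_mask = attn_mask.unsqueeze(1)  # add a broadcast dim for attention heads
```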
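After attention, the packed rows have to be restored to the original per-sample layout. The processor does this by trimming each row's right padding, splitting the flat token stream by per-sample lengths, and scattering the text and image chunks back through the original padding masks. A compact sketch of the unpack step, again with illustrative shapes:

```python
import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

# Two packed rows built from samples of lengths [5] and [3, 4] (hidden size 8).
sample_lengths = [5, 3, 4]             # per-sample lengths, in packing order
packed_lengths = torch.tensor([5, 7])  # real (non-pad) tokens per packed row
packed_padded = pad_sequence([torch.randn(5, 8), torch.randn(7, 8)], batch_first=True)

# 1. Strip the right padding of each packed row.
rows = unpad_sequence(packed_padded, lengths=packed_lengths, batch_first=True)
# 2. Flatten the real tokens and split them back into per-sample chunks.
flat = torch.cat(rows, dim=0)
per_sample = torch.split(flat, sample_lengths)
assert [chunk.shape[0] for chunk in per_sample] == sample_lengths
# Each chunk can then be written back into its padded (batch, seq, dim) slot
# via the text/latent attention masks, as the processor does.
```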