Unverified commit 0434db9a authored by OleehyO, committed by GitHub

[cogview4][feat] Support attention mechanism with variable-length support and batch packing (#11349)

* [cogview4] Enhance attention mechanism with variable-length support and batch packing

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent aff574fb
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -73,8 +73,9 @@ class CogView4AdaLayerNormZero(nn.Module):
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states = self.norm(hidden_states)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
@@ -111,8 +112,11 @@ class CogView4AdaLayerNormZero(nn.Module):
class CogView4AttnProcessor:
"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
"""
def __init__(self):
@@ -125,8 +129,10 @@ class CogView4AttnProcessor:
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +148,9 @@ class CogView4AttnProcessor:
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key)
key = attn.norm_k(key).to(dtype=dtype)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
@@ -159,13 +165,14 @@ class CogView4AttnProcessor:
# 4. Attention
if attention_mask is not None:
text_attention_mask = attention_mask.float().to(query.device)
actual_text_seq_length = text_attention_mask.size(1)
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
new_attention_mask = new_attention_mask.unsqueeze(2)
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
text_attn_mask = attention_mask
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
text_attn_mask = text_attn_mask.float().to(query.device)
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
mix_attn_mask[:, :text_seq_length] = text_attn_mask
mix_attn_mask = mix_attn_mask.unsqueeze(2)
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -183,9 +190,276 @@ class CogView4AttnProcessor:
return hidden_states, encoder_hidden_states
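For reference, a minimal sketch (not part of the commit) of how the processor above turns a per-token text mask into the dense matrix handed to F.scaled_dot_product_attention; the tensor names mirror the code, while the batch and sequence sizes are invented for illustration:

import torch

batch_size, text_seq_length, image_seq_length = 2, 4, 6
text_attn_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.float32)  # 1 = real text token, 0 = padding

mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length))
mix_attn_mask[:, :text_seq_length] = text_attn_mask                 # image tokens are never padded here
mix_attn_mask = mix_attn_mask.unsqueeze(2)                          # (B, S, 1)
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)    # outer product -> (B, S, S)
attention_mask = (attn_mask_matrix > 0).unsqueeze(1)                # (B, 1, S, S), True where attention is allowed

print(attention_mask.shape)  # torch.Size([2, 1, 10, 10])
# The processor then casts this matrix to the query dtype before calling scaled_dot_product_attention.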
class CogView4TrainingAttnProcessor:
"""
Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
embedding on query and key vectors, but does not include spatial normalization.
This processor differs from CogView4AttnProcessor in two important ways:
1. It supports attention masking with variable sequence lengths for multi-resolution training
2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
provided
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
latent_attn_mask: Optional[torch.Tensor] = None,
text_attn_mask: Optional[torch.Tensor] = None,
batch_flag: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
attn (`Attention`):
The attention module.
hidden_states (`torch.Tensor`):
The input hidden states.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states for cross-attention.
latent_attn_mask (`torch.Tensor`, *optional*):
Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
num_latent_tokens).
text_attn_mask (`torch.Tensor`, *optional*):
Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
is used for all text tokens.
batch_flag (`torch.Tensor`, *optional*):
Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
batch1, and samples 3-4 form batch2. If None, no packing is used.
image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
The rotary embedding for the image part of the input.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
"""
# Get dimensions and device info
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
dtype = encoder_hidden_states.dtype
device = encoder_hidden_states.device
latent_hidden_states = hidden_states
# Combine text and image streams for joint processing
mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
# 1. Construct the attention mask and optionally pack the inputs
# Create default masks if not provided
if text_attn_mask is None:
text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
if latent_attn_mask is None:
latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
# Validate mask shapes and types
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
# Create combined mask for text and image tokens
mixed_attn_mask = torch.ones(
(batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
)
mixed_attn_mask[:, :text_seq_length] = text_attn_mask
mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
# Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
# Handle batch packing if enabled
if batch_flag is not None:
assert batch_flag.dim() == 1
# Determine packed batch size based on batch_flag
packing_batch_size = torch.max(batch_flag).item() + 1
# Calculate actual sequence lengths for each sample based on masks
text_seq_length = torch.sum(text_attn_mask, dim=1)
latent_seq_length = torch.sum(latent_attn_mask, dim=1)
mixed_seq_length = text_seq_length + latent_seq_length
# Calculate packed sequence lengths for each packed batch
mixed_seq_length_packed = [
torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
]
assert len(mixed_seq_length_packed) == packing_batch_size
# Pack sequences by removing padding tokens
mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
# Split the unpadded sequence into packed batches
mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
# Re-pad to create packed batches with right-side padding
mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
mixed_hidden_states_packed,
batch_first=True,
padding_value=0.0,
padding_side="right",
)
# Create attention mask for packed batches
l = mixed_hidden_states_packed_padded.shape[1]
attn_mask_matrix = torch.zeros(
(packing_batch_size, l, l),
dtype=dtype,
device=device,
)
# Fill attention mask with block diagonal matrices
# This ensures that tokens can only attend to other tokens within the same original sample
for idx, mask in enumerate(attn_mask_matrix):
seq_lengths = mixed_seq_length[batch_flag == idx]
offset = 0
for length in seq_lengths:
# Create a block of 1s for each sample in the packed batch
mask[offset : offset + length, offset : offset + length] = 1
offset += length
attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim
attention_mask = attn_mask_matrix
# Prepare hidden states for attention computation
if batch_flag is None:
# If no packing, just combine text and image tokens
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
else:
# If packing, use the packed sequence
hidden_states = mixed_hidden_states_packed_padded
# 2. QKV projections - convert hidden states to query, key, value
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 3. QK normalization - apply layer norm to queries and keys if configured
if attn.norm_q is not None:
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key).to(dtype=dtype)
# 4. Apply rotary positional embeddings to image tokens only
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if batch_flag is None:
# Apply RoPE only to image tokens (after text tokens)
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
else:
# For packed batches, need to carefully apply RoPE to appropriate tokens
assert query.shape[0] == packing_batch_size
assert key.shape[0] == packing_batch_size
assert len(image_rotary_emb) == batch_size
rope_idx = 0
for idx in range(packing_batch_size):
offset = 0
# Get text and image sequence lengths for samples in this packed batch
text_seq_length_bi = text_seq_length[batch_flag == idx]
latent_seq_length_bi = latent_seq_length[batch_flag == idx]
# Apply RoPE to each image segment in the packed sequence
for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
mlen = tlen + llen
# Apply RoPE only to image tokens (after text tokens)
query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
query[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
key[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
offset += mlen
rope_idx += 1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection - project attention output to model dimension
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
# Split the output back into text and image streams
if batch_flag is None:
# Simple split for non-packed case
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
else:
# For packed case: need to unpack, split text/image, then restore to original shapes
# First, unpad the sequence based on the packed sequence lengths
hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
hidden_states,
lengths=torch.tensor(mixed_seq_length_packed),
batch_first=True,
)
# Concatenate all unpadded sequences
hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
# Split by original sample sequence lengths
hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
assert len(hidden_states_unpack) == batch_size
# Further split each sample's sequence into text and image parts
hidden_states_unpack = [
torch.split(h, [tlen, llen])
for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
]
# Separate text and image sequences
encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
hidden_states_unpad = [h[1] for h in hidden_states_unpack]
# Update the original tensors with the processed values, respecting the attention masks
for idx in range(batch_size):
# Place unpacked text tokens back in the encoder_hidden_states tensor
encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
# Place unpacked image tokens back in the latent_hidden_states tensor
latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
# Update the output hidden states
hidden_states = latent_hidden_states
return hidden_states, encoder_hidden_states
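A similarly minimal sketch (not part of the commit) of the batch packing described in the docstring: with batch_flag = [0, 1, 1], samples 1 and 2 are concatenated into one packed row, and the attention matrix becomes block-diagonal so tokens never attend across the original samples. The sequence lengths below are invented for illustration:

import torch

batch_flag = torch.tensor([0, 1, 1])        # sample 0 -> packed batch 0, samples 1 and 2 -> packed batch 1
mixed_seq_length = torch.tensor([5, 3, 4])  # per-sample (text + latent) token count after dropping padding
packing_batch_size = int(batch_flag.max()) + 1

# Total tokens per packed row: [5, 7]; rows are right-padded to the longest one
packed_lengths = [int(mixed_seq_length[batch_flag == b].sum()) for b in range(packing_batch_size)]
max_len = max(packed_lengths)

attn_mask = torch.zeros((packing_batch_size, max_len, max_len), dtype=torch.bool)
for idx in range(packing_batch_size):
    offset = 0
    for length in mixed_seq_length[batch_flag == idx].tolist():
        # one block of True per original sample packed into this row
        attn_mask[idx, offset:offset + length, offset:offset + length] = True
        offset += length

attn_mask = attn_mask.unsqueeze(1)  # (packed_batch, 1, max_len, max_len), broadcast over attention heads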
class CogView4TransformerBlock(nn.Module):
def __init__(
self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
) -> None:
super().__init__()
@@ -213,9 +487,11 @@ class CogView4TransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# 1. Timestep conditioning
(
@@ -232,12 +508,14 @@ class CogView4TransformerBlock(nn.Module):
) = self.norm1(hidden_states, encoder_hidden_states, temb)
# 2. Attention
if attention_kwargs is None:
attention_kwargs = {}
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**kwargs,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -402,7 +680,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
@@ -422,7 +702,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
if image_rotary_emb is None:
image_rotary_emb = self.rope(hidden_states)
# 2. Patch & Timestep embeddings
p = self.config.patch_size
@@ -438,11 +719,22 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
# 4. Output norm & projection
......
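Stepping back from the diff, here is a self-contained sketch of the unpad, pack, pad, unpad round trip that CogView4TrainingAttnProcessor performs around its attention call. The toy sequences, hidden size, and lengths are invented for illustration; only the torch.nn.utils.rnn helpers mirror the code above:

import torch
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

# Three variable-length token sequences (hidden size 8), as left after dropping padding tokens
seqs = [torch.randn(5, 8), torch.randn(3, 8), torch.randn(4, 8)]
batch_flag = torch.tensor([0, 1, 1])

# Pack: concatenate samples that share a batch_flag value into a single row
packed = [torch.cat([s for s, f in zip(seqs, batch_flag) if f == b], dim=0)
          for b in range(int(batch_flag.max()) + 1)]
packed_lengths = torch.tensor([p.shape[0] for p in packed])          # tensor([5, 7])

padded = pad_sequence(packed, batch_first=True, padding_value=0.0)   # (2, 7, 8), right-padded

# ... attention runs over `padded` with the block-diagonal mask from the sketch above ...

# Unpack: drop the padding, then split back into the original per-sample sequences
rows = unpad_sequence(padded, lengths=packed_lengths, batch_first=True)
restored = torch.split(torch.cat(rows, dim=0), [5, 3, 4])
assert all(torch.equal(r, s) for r, s in zip(restored, seqs))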