# Copyright 2024 The Genmo team and The HuggingFace Team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import PeftAdapterMixin from diffusers.models.attention import FeedForward as HF_FeedForward from diffusers.models.attention_processor import Attention from diffusers.models.embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNormContinuous from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from diffusers.utils.torch_utils import maybe_allow_in_graph from liger_kernel.ops.swiglu import LigerSiLUMulFunction from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad from fastvideo.models.mochi_hf.norm import (MochiLayerNormContinuous, MochiModulatedRMSNorm, MochiRMSNorm, MochiRMSNormZero) from fastvideo.utils.communications import all_gather, all_to_all_4D from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info logger = logging.get_logger(__name__) # pylint: disable=invalid-name class FeedForward(HF_FeedForward): def __init__( self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, inner_dim=None, bias: bool = True, ): super().__init__(dim, dim_out, mult, dropout, activation_fn, final_dropout, inner_dim, bias) assert activation_fn == "swiglu" def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.net[0].proj(hidden_states) hidden_states, gate = hidden_states.chunk(2, dim=-1) return self.net[2](LigerSiLUMulFunction.apply(gate, hidden_states)) class MochiAttention(nn.Module): def __init__( self, query_dim: int, processor: "MochiAttnProcessor2_0", heads: int = 8, dim_head: int = 64, dropout: float = 0.0, bias: bool = False, added_kv_proj_dim: Optional[int] = None, added_proj_bias: Optional[bool] = True, out_dim: int = None, out_context_dim: int = None, out_bias: bool = True, context_pre_only: bool = False, eps: float = 1e-5, ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.out_dim = out_dim if out_dim is not None else query_dim self.out_context_dim = out_context_dim if out_context_dim else query_dim self.context_pre_only = context_pre_only self.heads = out_dim // dim_head if out_dim is not None else heads self.norm_q = MochiRMSNorm(dim_head, eps) self.norm_k = MochiRMSNorm(dim_head, eps) self.norm_added_q = MochiRMSNorm(dim_head, eps) self.norm_added_k = MochiRMSNorm(dim_head, eps) self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) if not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) self.processor = processor def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs, ): return self.processor( self, hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **kwargs, ) class MochiAttnProcessor2_0: """Attention processor used in Mochi.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: # [b, s, h * d] query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) # [b, s, h=24, d=128] query = query.unflatten(2, (attn.heads, -1)) key = key.unflatten(2, (attn.heads, -1)) value = value.unflatten(2, (attn.heads, -1)) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # [b, 256, h * d] encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) # [b, 256, h=24, d=128] encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) if attn.norm_added_q is not None: encoder_query = attn.norm_added_q(encoder_query) if attn.norm_added_k is not None: encoder_key = attn.norm_added_k(encoder_key) if image_rotary_emb is not None: freqs_cos, freqs_sin = image_rotary_emb[0], image_rotary_emb[1] # shard the head dimension if get_sequence_parallel_state(): # B, S, H, D to (S, B,) H, D # batch_size, seq_len, attn_heads, head_dim query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) key = all_to_all_4D(key, scatter_dim=2, gather_dim=1) value = all_to_all_4D(value, scatter_dim=2, gather_dim=1) def shrink_head(encoder_state, dim): local_heads = encoder_state.shape[dim] // nccl_info.sp_size return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads) encoder_query = shrink_head(encoder_query, dim=2) encoder_key = shrink_head(encoder_key, dim=2) encoder_value = shrink_head(encoder_value, dim=2) if image_rotary_emb is not None: freqs_cos = shrink_head(freqs_cos, dim=1) freqs_sin = shrink_head(freqs_sin, dim=1) if image_rotary_emb is not None: def apply_rotary_emb(x, freqs_cos, freqs_sin): x_even = x[..., 0::2].float() x_odd = x[..., 1::2].float() cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype) sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype) return torch.stack([cos, sin], dim=-1).flatten(-2) query = apply_rotary_emb(query, freqs_cos, freqs_sin) key = apply_rotary_emb(key, freqs_cos, freqs_sin) # query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) # encoder_query, encoder_key, encoder_value = ( # encoder_query.transpose(1, 2), # encoder_key.transpose(1, 2), # encoder_value.transpose(1, 2), # ) # [b, s, h, d] sequence_length = query.size(1) encoder_sequence_length = encoder_query.size(1) # H query = torch.cat([query, encoder_query], dim=1).unsqueeze(2) key = torch.cat([key, encoder_key], dim=1).unsqueeze(2) value = torch.cat([value, encoder_value], dim=1).unsqueeze(2) # B, S, 3, H, D qkv = torch.cat([query, key, value], dim=2) attn_mask = encoder_attention_mask[:, :].bool() attn_mask = F.pad(attn_mask, (sequence_length, 0), value=True) hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None) # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask = None, dropout_p=0.0, is_causal=False) # valid_lengths = encoder_attention_mask.sum(dim=1) + sequence_length # def no_padding_mask(score, b, h, q_idx, kv_idx): # return torch.where(kv_idx < valid_lengths[b],score, -float("inf")) # hidden_states = flex_attention(query, key, value, score_mod=no_padding_mask) if get_sequence_parallel_state(): hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1) # B, S, H, D hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2) encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous() hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) encoder_hidden_states = encoder_hidden_states.flatten(2, 3) encoder_hidden_states = encoder_hidden_states.to(query.dtype) else: hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) hidden_states, encoder_hidden_states = hidden_states.split_with_sizes( (sequence_length, encoder_sequence_length), dim=1) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if hasattr(attn, "to_add_out"): encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states @maybe_allow_in_graph class MochiTransformerBlock(nn.Module): r""" Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). Args: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. qk_norm (`str`, defaults to `"rms_norm"`): The normalization layer to use. activation_fn (`str`, defaults to `"swiglu"`): Activation function to use in feed-forward. context_pre_only (`bool`, defaults to `False`): Whether or not to process context-related conditions with additional layers. eps (`float`, defaults to `1e-6`): Epsilon value for normalization layers. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, pooled_projection_dim: int, qk_norm: str = "rms_norm", activation_fn: str = "swiglu", context_pre_only: bool = False, eps: float = 1e-6, ) -> None: super().__init__() self.context_pre_only = context_pre_only self.ff_inner_dim = (4 * dim * 2) // 3 self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3 self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False) if not context_pre_only: self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False) else: self.norm1_context = MochiLayerNormContinuous( embedding_dim=pooled_projection_dim, conditioning_embedding_dim=dim, eps=eps, ) self.attn1 = MochiAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, bias=False, added_kv_proj_dim=pooled_projection_dim, added_proj_bias=False, out_dim=dim, out_context_dim=pooled_projection_dim, context_pre_only=context_pre_only, processor=MochiAttnProcessor2_0(), eps=1e-5, ) # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True self.norm2 = MochiModulatedRMSNorm(eps=eps) self.norm2_context = (MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None) self.norm3 = MochiModulatedRMSNorm(eps) self.norm3_context = (MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None) self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False) self.ff_context = None if not context_pre_only: self.ff_context = FeedForward( pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False, ) self.norm4 = MochiModulatedRMSNorm(eps=eps) self.norm4_context = MochiModulatedRMSNorm(eps=eps) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None, output_attn=False, ) -> Tuple[torch.Tensor, torch.Tensor]: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) if not self.context_pre_only: ( norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp, ) = self.norm1_context(encoder_hidden_states, temb) else: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) attn_hidden_states, context_attn_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, encoder_attention_mask=encoder_attention_mask, ) hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)) norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32))) ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)) if not self.context_pre_only: encoder_hidden_states = encoder_hidden_states + self.norm2_context(context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)) norm_encoder_hidden_states = self.norm3_context( encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32)), ) context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)) if not output_attn: attn_hidden_states = None return hidden_states, encoder_hidden_states, attn_hidden_states class MochiRoPE(nn.Module): r""" RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview). Args: base_height (`int`, defaults to `192`): Base height used to compute interpolation scale for rotary positional embeddings. base_width (`int`, defaults to `192`): Base width used to compute interpolation scale for rotary positional embeddings. """ def __init__(self, base_height: int = 192, base_width: int = 192) -> None: super().__init__() self.target_area = base_height * base_width def _centers(self, start, stop, num, device, dtype) -> torch.Tensor: edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype) return (edges[:-1] + edges[1:]) / 2 def _get_positions( self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: scale = (self.target_area / (height * width))**0.5 t = torch.arange(num_frames * nccl_info.sp_size, device=device, dtype=dtype) h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype) w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype) grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij") positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3) return positions def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: with torch.autocast(freqs.device.type, enabled=False): # Always run ROPE freqs computation in FP32 freqs = torch.einsum( "nd,dhf->nhf", # codespell:ignore pos.to(torch.float32), # codespell:ignore freqs.to(torch.float32)) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin def forward( self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: pos = self._get_positions(num_frames, height, width, device, dtype) rope_cos, rope_sin = self._create_rope(pos_frequencies, pos) return rope_cos, rope_sin @maybe_allow_in_graph class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview). Args: patch_size (`int`, defaults to `2`): The size of the patches to use in the patch embedding layer. num_attention_heads (`int`, defaults to `24`): The number of heads to use for multi-head attention. attention_head_dim (`int`, defaults to `128`): The number of channels in each head. num_layers (`int`, defaults to `48`): The number of layers of Transformer blocks to use. in_channels (`int`, defaults to `12`): The number of channels in the input. out_channels (`int`, *optional*, defaults to `None`): The number of channels in the output. qk_norm (`str`, defaults to `"rms_norm"`): The normalization layer to use. text_embed_dim (`int`, defaults to `4096`): Input dimension of text embeddings from the text encoder. time_embed_dim (`int`, defaults to `256`): Output dimension of timestep embeddings. activation_fn (`str`, defaults to `"swiglu"`): Activation function to use in feed-forward. max_sequence_length (`int`, defaults to `256`): The maximum sequence length of text embeddings supported. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, patch_size: int = 2, num_attention_heads: int = 24, attention_head_dim: int = 128, num_layers: int = 48, pooled_projection_dim: int = 1536, in_channels: int = 12, out_channels: Optional[int] = None, qk_norm: str = "rms_norm", text_embed_dim: int = 4096, time_embed_dim: int = 256, activation_fn: str = "swiglu", max_sequence_length: int = 256, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels self.patch_embed = PatchEmbed( patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, pos_embed_type=None, ) self.time_embed = MochiCombinedTimestepCaptionEmbedding( embedding_dim=inner_dim, pooled_projection_dim=pooled_projection_dim, text_embed_dim=text_embed_dim, time_embed_dim=time_embed_dim, num_attention_heads=8, ) self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0)) self.rope = MochiRoPE() self.transformer_blocks = nn.ModuleList([ MochiTransformerBlock( dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, pooled_projection_dim=pooled_projection_dim, qk_norm=qk_norm, activation_fn=activation_fn, context_pre_only=i == num_layers - 1, ) for i in range(num_layers) ]) self.norm_out = AdaLayerNormContinuous( inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm", ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) self.gradient_checkpointing = False def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, encoder_attention_mask: torch.Tensor, output_features=False, output_features_stride=8, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ) -> torch.Tensor: assert (return_dict is False), "return_dict is not supported in MochiTransformer3DModel" if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if (attention_kwargs is not None and attention_kwargs.get("scale", None) is not None): logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.") batch_size, num_channels, num_frames, height, width = hidden_states.shape p = self.config.patch_size post_patch_height = height // p post_patch_width = width // p # Peiyuan: This is hacked to force mochi to follow the behaviour of SD3 and Flux timestep = 1000 - timestep temb, encoder_hidden_states = self.time_embed( timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype, ) hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) image_rotary_emb = self.rope( self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32, ) attn_outputs_list = [] for i, block in enumerate(self.transformer_blocks): if self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = ({"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}) ( hidden_states, encoder_hidden_states, attn_outputs, ) = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, encoder_attention_mask, temb, image_rotary_emb, output_features, **ckpt_kwargs, ) else: hidden_states, encoder_hidden_states, attn_outputs = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, temb=temb, image_rotary_emb=image_rotary_emb, output_attn=output_features, ) if i % output_features_stride == 0: attn_outputs_list.append(attn_outputs) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1) hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5) output = hidden_states.reshape(batch_size, -1, num_frames, height, width) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not output_features: attn_outputs_list = None else: attn_outputs_list = torch.stack(attn_outputs_list, dim=0) # Peiyuan: This is hacked to force mochi to follow the behaviour of SD3 and Flux return (-output, attn_outputs_list)