Unverified commit 3f329a42 authored by Aryan, committed by GitHub

[core] Mochi T2V (#9769)



* update

* udpate

* update transformer

* make style

* fix

* add conversion script

* update

* fix

* update

* fix

* update

* fixes

* make style

* update

* update

* update

* init

* update

* update

* add

* up

* up

* up

* update

* mochi transformer

* remove original implementation

* make style

* update inits

* update conversion script

* docs

* Update src/diffusers/pipelines/mochi/pipeline_mochi.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/pipelines/mochi/pipeline_mochi.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fix docs

* pipeline fixes

* make style

* invert sigmas in scheduler; fix pipeline

* fix pipeline num_frames

* flip proj and gate in swiglu

* make style

* fix

* make style

* fix tests

* latent mean and std fix

* update

* cherry-pick 1069d210e1b9e84a366cdc7a13965626ea258178

* remove additional sigma already handled by flow match scheduler

* fix

* remove hardcoded value

* replace conv1x1 with linear

* Update src/diffusers/pipelines/mochi/pipeline_mochi.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* framewise decoding and conv_cache

* make style

* Apply suggestions from code review

* mochi vae encoder changes

* rebase correctly

* Update scripts/convert_mochi_to_diffusers.py

* fix tests

* fixes

* make style

* update

* make style

* update

* add framewise and tiled encoding

* make style

* make original vae implementation behaviour the default; note: framewise encoding does not work

* remove framewise encoding implementation due to presence of attn layers

* fight test 1

* fight test 2

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent a3cc641f
@@ -270,6 +270,8 @@
      title: LatteTransformer3DModel
    - local: api/models/lumina_nextdit2d
      title: LuminaNextDiT2DModel
    - local: api/models/mochi_transformer3d
      title: MochiTransformer3DModel
    - local: api/models/pixart_transformer2d
      title: PixArtTransformer2DModel
    - local: api/models/prior_transformer
@@ -306,6 +308,8 @@
      title: AutoencoderKLAllegro
    - local: api/models/autoencoderkl_cogvideox
      title: AutoencoderKLCogVideoX
    - local: api/models/autoencoderkl_mochi
      title: AutoencoderKLMochi
    - local: api/models/asymmetricautoencoderkl
      title: AsymmetricAutoencoderKL
    - local: api/models/consistency_decoder_vae
@@ -400,6 +404,8 @@
      title: Lumina-T2X
    - local: api/pipelines/marigold
      title: Marigold
    - local: api/pipelines/mochi
      title: Mochi
    - local: api/pipelines/panorama
      title: MultiDiffusion
    - local: api/pipelines/musicldm
......
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLMochi
The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import AutoencoderKLMochi

vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
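For reference, here is a minimal decoding sketch. The latent shape and the `enable_tiling()` call are assumptions based on the usual diffusers VAE API and Mochi's 12 latent channels with roughly 8x spatial and 6x temporal compression; consult the API reference below for the exact methods and shapes.

```python
import torch

from diffusers import AutoencoderKLMochi

vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
vae.enable_tiling()  # assumed tiling helper to keep framewise decoding memory-friendly

# Dummy latents: (batch, latent_channels, latent_frames, latent_height, latent_width).
latents = torch.randn(1, 12, 15, 60, 106, dtype=torch.float32, device="cuda")

with torch.no_grad():
    video = vae.decode(latents).sample  # (batch, 3, num_frames, height, width)
print(video.shape)
```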
## AutoencoderKLMochi
[[autodoc]] AutoencoderKLMochi
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# MochiTransformer3DModel
A Diffusion Transformer model for 3D video-like data, introduced by Genmo in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview).
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import MochiTransformer3DModel

transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
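The transformer's `forward` takes video latents, text-encoder hidden states, a timestep, and a text attention mask (see the `forward` definition in `transformer_mochi.py` later in this diff). A minimal shape-check sketch with random inputs, using the default config values (12 latent channels, 4096-dim text embeddings, maximum text length 256), could look like this; the latent sizes themselves are illustrative assumptions.

```python
import torch

from diffusers import MochiTransformer3DModel

transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16
).to("cuda")

batch_size, num_frames, height, width = 1, 7, 60, 106  # latent-space sizes; height/width divisible by patch_size=2
hidden_states = torch.randn(batch_size, 12, num_frames, height, width, dtype=torch.float16, device="cuda")
encoder_hidden_states = torch.randn(batch_size, 256, 4096, dtype=torch.float16, device="cuda")
encoder_attention_mask = torch.ones(batch_size, 256, dtype=torch.bool, device="cuda")
timestep = torch.tensor([500], device="cuda")

with torch.no_grad():
    sample = transformer(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
        encoder_attention_mask=encoder_attention_mask,
    ).sample
print(sample.shape)  # same latent shape as the input
```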
## MochiTransformer3DModel
[[autodoc]] MochiTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# Mochi
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) is a text-to-video generation model from Genmo.
*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
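A minimal text-to-video sketch is shown below. The call arguments (`num_frames`, `num_inference_steps`) and the use of `export_to_video` follow common diffusers video-pipeline conventions and are assumptions here; see the `__call__` reference below for the exact signature.

```python
import torch

from diffusers import MochiPipeline
from diffusers.utils import export_to_video

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # reduce peak VRAM by offloading idle components to CPU

prompt = "A close-up of a chameleon walking along a branch, shallow depth of field, natural light."
video = pipe(prompt=prompt, num_frames=84, num_inference_steps=50).frames[0]

export_to_video(video, "mochi.mp4", fps=30)
```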
## MochiPipeline
[[autodoc]] MochiPipeline
- all
- __call__
## MochiPipelineOutput
[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput
@@ -83,6 +83,7 @@ else:
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
@@ -102,6 +103,7 @@ else:
"Kandinsky3UNet",
"LatteTransformer3DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -311,6 +313,7 @@ else:
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
@@ -565,6 +568,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
@@ -584,6 +588,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3UNet,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -772,6 +777,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
PIAPipeline,
......
@@ -30,6 +30,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
@@ -58,6 +59,7 @@ if is_torch_available():
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -85,6 +87,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
@@ -110,6 +113,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DModel,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
......
@@ -136,6 +136,7 @@ class SwiGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
self.activation = nn.SiLU()
......
@@ -120,14 +120,16 @@ class Attention(nn.Module):
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
is_causal: bool = False,
):
super().__init__()
# To prevent circular import.
from .normalization import FP32LayerNorm, LpNorm, RMSNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -142,8 +144,10 @@ class Attention(nn.Module):
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.is_causal = is_causal
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -192,6 +196,9 @@ class Attention(nn.Module):
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "l2":
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
@@ -241,7 +248,7 @@ class Attention(nn.Module):
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
@@ -1886,6 +1893,7 @@ class FluxAttnProcessor2_0:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
@@ -2714,6 +2722,91 @@ class AttnProcessor2_0:
return hidden_states
class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
is_single_frame = hidden_states.shape[1] == 1
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if is_single_frame:
hidden_states = attn.to_v(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class StableAudioAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -3389,6 +3482,94 @@ class LuminaAttnProcessor2_0:
return hidden_states
class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
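# Rotate interleaved (even, odd) feature pairs by the precomputed cos/sin tables, i.e. treat each pair as the real/imaginary parts of a complex rotation.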
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
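# Joint attention: concatenate the video and text token sequences, attend over them together, then split the outputs back apart below.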
query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
......
@@ -2,6 +2,7 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
......
@@ -94,11 +94,13 @@ class CogVideoXCausalConv3d(nn.Module):
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
# TODO(aryan): configure calculation based on stride and dilation in the future.
# Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
time_pad = time_kernel_size - 1
height_pad = (height_kernel_size - 1) // 2
width_pad = (width_kernel_size - 1) // 2
self.pad_mode = pad_mode
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
@@ -107,7 +109,7 @@ class CogVideoXCausalConv3d(nn.Module):
self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = CogVideoXSafeConv3d(
in_channels=in_channels,
@@ -120,18 +122,24 @@ class CogVideoXCausalConv3d(nn.Module):
def fake_context_parallel_forward(
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.pad_mode == "replicate":
inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
else:
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
if self.pad_mode == "replicate":
conv_cache = None
else:
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
return output, conv_cache
......
@@ -1356,6 +1356,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
return conditioning
class MochiCombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
pooled_projection_dim: int,
text_embed_dim: int,
time_embed_dim: int = 256,
num_attention_heads: int = 8,
) -> None:
super().__init__()
self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
self.pooler = MochiAttentionPool(
num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
)
self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
def forward(
self,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
hidden_dtype: Optional[torch.dtype] = None,
):
time_proj = self.time_proj(timestep)
time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
caption_proj = self.caption_proj(encoder_hidden_states)
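# The global conditioning is the timestep embedding plus the attention-pooled caption embedding; the per-token caption projections are returned separately for the transformer blocks.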
conditioning = time_emb + pooled_projections
return conditioning, caption_proj
class TextTimeEmbedding(nn.Module):
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
super().__init__()
@@ -1484,6 +1519,88 @@ class AttentionPooling(nn.Module):
return a[:, 0, :] # cls_token
class MochiAttentionPool(nn.Module):
def __init__(
self,
num_attention_heads: int,
embed_dim: int,
output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.output_dim = output_dim or embed_dim
self.num_attention_heads = num_attention_heads
self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
self.to_q = nn.Linear(embed_dim, embed_dim)
self.to_out = nn.Linear(embed_dim, self.output_dim)
@staticmethod
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
"""
Pool tokens in x using mask.
NOTE: We assume x does not require gradients.
Args:
x: (B, L, D) tensor of tokens.
mask: (B, L) boolean tensor indicating which tokens are not padding.
Returns:
pooled: (B, D) tensor of pooled tokens.
"""
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
mask = mask[:, :, None].to(dtype=x.dtype)
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
r"""
Args:
x (`torch.Tensor`):
Tensor of shape `(B, S, D)` of input tokens.
mask (`torch.Tensor`):
Boolean tensor of shape `(B, S)` indicating which tokens are not padding.
Returns:
`torch.Tensor`:
`(B, D)` tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_attention_heads
kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:
......
@@ -234,6 +234,33 @@ class LuminaRMSNormZero(nn.Module):
return x, gate_msa, scale_mlp, gate_mlp
class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
hidden_dim (`int`): The output size of the conditioning projection, chunked into scale and gate values.
"""
def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
return hidden_states, gate_msa, scale_mlp, gate_mlp
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -356,20 +383,21 @@ class LuminaLayerNormContinuous(nn.Module):
out_dim: Optional[int] = None,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
# linear_2
self.linear_2 = None
if out_dim is not None:
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
def forward(
self,
@@ -526,3 +554,15 @@ class GlobalResponseNorm(nn.Module):
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * nx) + self.beta + x
class LpNorm(nn.Module):
def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
super().__init__()
self.p = p
self.dim = dim
self.eps = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
@@ -17,5 +17,6 @@ if is_torch_available():
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
# Copyright 2024 The Genmo team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, MochiAttnProcessor2_0
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
r"""
Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
activation_fn (`str`, defaults to `"swiglu"`):
Activation function to use in feed-forward.
context_pre_only (`bool`, defaults to `False`):
Whether or not to process context-related conditions with additional layers.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
pooled_projection_dim: int,
qk_norm: str = "rms_norm",
activation_fn: str = "swiglu",
context_pre_only: bool = False,
eps: float = 1e-6,
) -> None:
super().__init__()
self.context_pre_only = context_pre_only
self.ff_inner_dim = (4 * dim * 2) // 3
self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3
self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)
if not context_pre_only:
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
else:
self.norm1_context = LuminaLayerNormContinuous(
embedding_dim=pooled_projection_dim,
conditioning_embedding_dim=dim,
eps=eps,
elementwise_affine=False,
norm_type="rms_norm",
out_dim=None,
)
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=False,
qk_norm=qk_norm,
added_kv_proj_dim=pooled_projection_dim,
added_proj_bias=False,
out_dim=dim,
out_context_dim=pooled_projection_dim,
context_pre_only=context_pre_only,
processor=MochiAttnProcessor2_0(),
eps=eps,
elementwise_affine=True,
)
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
self.ff_context = None
if not context_pre_only:
self.ff_context = FeedForward(
pooled_projection_dim,
inner_dim=self.ff_context_inner_dim,
activation_fn=activation_fn,
bias=False,
)
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
if not self.context_pre_only:
norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
encoder_hidden_states, temb
)
else:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
attn_hidden_states, context_attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
if not self.context_pre_only:
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
context_attn_hidden_states
) * torch.tanh(enc_gate_msa).unsqueeze(1)
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
enc_gate_mlp
).unsqueeze(1)
return hidden_states, encoder_hidden_states
class MochiRoPE(nn.Module):
r"""
RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
Args:
base_height (`int`, defaults to `192`):
Base height used to compute interpolation scale for rotary positional embeddings.
base_width (`int`, defaults to `192`):
Base width used to compute interpolation scale for rotary positional embeddings.
"""
def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
super().__init__()
self.target_area = base_height * base_width
def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
return (edges[:-1] + edges[1:]) / 2
def _get_positions(
self,
num_frames: int,
height: int,
width: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
scale = (self.target_area / (height * width)) ** 0.5
t = torch.arange(num_frames, device=device, dtype=dtype)
h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
return positions
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
def forward(
self,
pos_frequencies: torch.Tensor,
num_frames: int,
height: int,
width: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
pos = self._get_positions(num_frames, height, width, device, dtype)
rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
return rope_cos, rope_sin
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin):
r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
Args:
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
num_layers (`int`, defaults to `48`):
The number of layers of Transformer blocks to use.
in_channels (`int`, defaults to `12`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `None`):
The number of channels in the output.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `256`):
Output dimension of timestep embeddings.
activation_fn (`str`, defaults to `"swiglu"`):
Activation function to use in feed-forward.
max_sequence_length (`int`, defaults to `256`):
The maximum sequence length of text embeddings supported.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 2,
num_attention_heads: int = 24,
attention_head_dim: int = 128,
num_layers: int = 48,
pooled_projection_dim: int = 1536,
in_channels: int = 12,
out_channels: Optional[int] = None,
qk_norm: str = "rms_norm",
text_embed_dim: int = 4096,
time_embed_dim: int = 256,
activation_fn: str = "swiglu",
max_sequence_length: int = 256,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
pos_embed_type=None,
)
self.time_embed = MochiCombinedTimestepCaptionEmbedding(
embedding_dim=inner_dim,
pooled_projection_dim=pooled_projection_dim,
text_embed_dim=text_embed_dim,
time_embed_dim=time_embed_dim,
num_attention_heads=8,
)
self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
self.rope = MochiRoPE()
self.transformer_blocks = nn.ModuleList(
[
MochiTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
pooled_projection_dim=pooled_projection_dim,
qk_norm=qk_norm,
activation_fn=activation_fn,
context_pre_only=i == num_layers - 1,
)
for i in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
return_dict: bool = True,
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
temb, encoder_hidden_states = self.time_embed(
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
image_rotary_emb = self.rope(
self.pos_frequencies,
num_frames,
post_patch_height,
post_patch_width,
device=hidden_states.device,
dtype=torch.float32,
)
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
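# Unpatchify: fold each token's patch back into the latent video layout (batch, channels, frames, height, width).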
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -247,6 +247,7 @@ else:
"MarigoldNormalsPipeline",
]
)
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
@@ -571,6 +572,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
)
from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_mochi"] = ["MochiPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_mochi import MochiPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class MochiPipelineOutput(BaseOutput):
r"""
Output class for Mochi pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor