Unverified Commit 13f20c7f authored by Aryan, committed by GitHub

[refactor] SD3 docs & remove additional code (#10882)

* update

* update

* update
parent 87599691
@@ -1410,7 +1410,7 @@ class JointAttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
...
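For context, `JointAttnProcessor2_0` (whose error message is corrected above) dispatches attention to PyTorch 2.0's fused kernel. A minimal, hedged sketch of the call it relies on; the tensor shapes here are illustrative only:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim)
query = torch.randn(1, 18, 333, 64)
key = torch.randn(1, 18, 333, 64)
value = torch.randn(1, 18, 333, 64)

# Only available on PyTorch >= 2.0, hence the ImportError guard in the processor.
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 18, 333, 64])
```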
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):

class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning in the patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The indices of the transformer blocks that use dual-stream attention.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for the pooled projection embeddings. The pipelines read this value from the
+            ControlNet model's config.
+    """
+
    _supports_gradient_checkpointing = True

    @register_to_config
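A hedged usage sketch for the class documented above: in practice the ControlNet is loaded from a pretrained checkpoint and plugged into the SD3 ControlNet pipeline. The repo ids below are assumptions based on the publicly released SD3 checkpoints (`stabilityai/stable-diffusion-3-medium-diffusers` is gated and requires accepting the license):

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline

controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# image = pipe(prompt, control_image=canny_image, controlnet_conditioning_scale=0.7).images[0]
```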
@@ -93,7 +135,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                    context_pre_only=False,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                SD3SingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
@@ -297,28 +339,28 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    def forward(
        self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`SD3Transformer2DModel`] forward method.

        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
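A hedged note on `conditioning_scale`: conceptually, each ControlNet block output is scaled by this factor before being added as a residual to the corresponding block of the base transformer. The helper below is illustrative only, not the library's internals:

```python
import torch

def apply_controlnet_residuals(block_hidden_states, controlnet_block_samples, conditioning_scale=1.0):
    # Illustrative: scale the ControlNet block outputs and add them as residuals.
    return [
        hidden + conditioning_scale * control
        for hidden, control in zip(block_hidden_states, controlnet_block_samples)
    ]

# conditioning_scale=0.0 disables the ControlNet; 1.0 applies it at full strength.
blocks = [torch.randn(1, 4096, 1536) for _ in range(2)]
controls = [torch.randn(1, 4096, 1536) for _ in range(2)]
out = apply_controlnet_residuals(blocks, controls, conditioning_scale=0.7)
```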
@@ -437,11 +479,11 @@ class SD3MultiControlNetModel(ModelMixin):
    def forward(
        self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
        controlnet_cond: List[torch.tensor],
        conditioning_scale: List[float],
-        pooled_projections: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
...
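`SD3MultiControlNetModel` takes one condition and one scale per ControlNet. A hedged sketch of the aggregation pattern (illustrative, not the exact implementation): run each ControlNet on its own condition and scale, then sum the per-block outputs elementwise:

```python
def multi_controlnet_forward(controlnets, controlnet_conds, conditioning_scales, **common_inputs):
    # `controlnets` is any sequence of callables returning a (block_samples,) tuple,
    # mirroring the per-ControlNet forward signature above.
    summed = None
    for net, cond, scale in zip(controlnets, controlnet_conds, conditioning_scales):
        block_samples = net(controlnet_cond=cond, conditioning_scale=scale, return_dict=False, **common_inputs)[0]
        summed = list(block_samples) if summed is None else [s + b for s, b in zip(summed, block_samples)]
    return summed
```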
@@ -15,7 +15,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
-import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
@@ -39,17 +38,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
-    r"""
-    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-    """
-
    def __init__(
        self,
        dim: int,
@@ -59,21 +47,13 @@ class SD3SingleTransformerBlock(nn.Module):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
-
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
-
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
-            processor=processor,
+            processor=JointAttnProcessor2_0(),
            eps=1e-6,
        )
@@ -81,23 +61,17 @@ class SD3SingleTransformerBlock(nn.Module):
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        # 1. Attention
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        # Attention.
-        attn_output = self.attn(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
-        )
-
-        # Process attention outputs for the `hidden_states`.
+        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

+        # 2. Feed Forward
        norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = hidden_states + ff_output

        return hidden_states
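The modulation change above is purely cosmetic: `scale_mlp.unsqueeze(1)` and `scale_mlp[:, None]` both insert a broadcast dimension so per-sample adaLN parameters of shape `(batch, dim)` apply across the sequence dimension of `(batch, seq_len, dim)` activations. A quick check:

```python
import torch

batch, seq_len, dim = 2, 8, 16
hidden = torch.randn(batch, seq_len, dim)
scale = torch.randn(batch, dim)
shift = torch.randn(batch, dim)

old = hidden * (1 + scale[:, None]) + shift[:, None]
new = hidden * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
assert torch.equal(old, new)  # identical broadcasting, purely a style change
```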
@@ -107,26 +81,40 @@ class SD3Transformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """
-    The Transformer model introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
+    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).

    Parameters:
-        sample_size (`int`): The width of the latent images. This is fixed during training since
-            it is used to learn a number of position embeddings.
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        out_channels (`int`, defaults to 16): Number of output channels.
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The indices of the transformer blocks that use dual-stream attention.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
    """

    _supports_gradient_checkpointing = True
+    _no_split_modules = ["JointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
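As a hedged usage sketch for the documented defaults, the transformer can be loaded on its own from an SD3 checkpoint (assuming access to the gated `stabilityai/stable-diffusion-3-medium-diffusers` repository; any SD3-family checkpoint with a `transformer` subfolder should work the same way):

```python
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # gated repo; accept the license first
    subfolder="transformer",
    torch_dtype=torch.float16,
)
print(transformer.config.num_layers, transformer.config.attention_head_dim)
```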
@@ -149,36 +137,33 @@ class SD3Transformer2DModel(
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
-        default_out_channels = in_channels
-        self.out_channels = out_channels if out_channels is not None else default_out_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
                )
-                for i in range(self.config.num_layers)
+                for i in range(num_layers)
            ]
        )
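With the documented defaults this wiring is easy to sanity-check: `inner_dim = num_attention_heads * attention_head_dim = 18 * 64 = 1152`, which matches the default `caption_projection_dim`, so the context embedder maps the 4096-dimensional joint text embeddings into the same width the transformer blocks operate in. A small sketch of just that arithmetic, assuming the defaults from the docstring:

```python
import torch.nn as nn

# Defaults assumed from the docstring above.
num_attention_heads, attention_head_dim = 18, 64
joint_attention_dim, caption_projection_dim = 4096, 1152

inner_dim = num_attention_heads * attention_head_dim
assert inner_dim == caption_projection_dim == 1152

context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
print(context_embedder)  # Linear(in_features=4096, out_features=1152, bias=True)
```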
@@ -331,24 +316,24 @@ class SD3Transformer2DModel(
    def forward(
        self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        skip_layers: Optional[List[int]] = None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`SD3Transformer2DModel`] forward method.

        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
...
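A hedged end-to-end sketch of the forward contract with random inputs. The config below is a deliberately tiny, made-up one (the documented defaults build a far larger model); shapes follow the docstring, and the sequence length is an arbitrary example:

```python
import torch
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=64,
    caption_projection_dim=32,   # must equal num_attention_heads * attention_head_dim
    pooled_projection_dim=48,
    out_channels=16,
)

hidden_states = torch.randn(1, 16, 32, 32)        # (batch, channel, height, width)
encoder_hidden_states = torch.randn(1, 154, 64)   # (batch, sequence_len, joint_attention_dim)
pooled_projections = torch.randn(1, 48)           # (batch, pooled_projection_dim)
timestep = torch.tensor([1])

with torch.no_grad():
    out = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
    )
print(out.sample.shape)  # torch.Size([1, 16, 32, 32])
```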