Unverified Commit e2d037bb authored by YiYi Xu, committed by GitHub

minor doc/test update (#9734)



* update some docs and tests!

---------
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
parent bcd61fd3
@@ -54,6 +54,11 @@ image = pipe(
image.save("sd3_hello_world.png")
```
**Note:** Stable Diffusion 3.5 can also be run with the SD3 pipeline, and all of the optimizations and techniques mentioned here apply to it as well. In total there are three official models in the SD3 family, any of which can be loaded as shown below:
- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
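For example, a minimal sketch of loading the 3.5 large checkpoint listed above with the same pipeline (the dtype, device, and sampler settings here are illustrative assumptions, not prescriptive):

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Stable Diffusion 3.5 uses the same pipeline class as SD3 medium.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")

image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("sd3_5_hello_world.png")
```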
## Memory Optimisations for SD3
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.
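As a minimal sketch of two commonly used options (assuming the `StableDiffusion3Pipeline` API shown above), the T5-XXL encoder can be dropped and idle components offloaded to the CPU; the sections below cover these options in more detail:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Skip the memory-heavy T5-XXL text encoder and offload idle components to the CPU.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

image = pipe("a photo of a cat", num_inference_steps=28, guidance_scale=7.0).images[0]
```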
...
@@ -16,10 +16,9 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--output_path", type=str)
-parser.add_argument("--dtype", type=str, default="fp16")
parser.add_argument("--dtype", type=str)
args = parser.parse_args()
-dtype = torch.float16 if args.dtype == "fp16" else torch.float32
def load_original_checkpoint(ckpt_path):
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
return new_weight
-def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
def convert_sd3_transformer_checkpoint_to_diffusers(
original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
):
converted_state_dict = {}
# Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# qk norm
if has_qk_norm:
converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.ln_q.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.ln_k.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.attn.ln_q.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.context_block.attn.ln_k.weight"
)
# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
f"joint_blocks.{i}.context_block.attn.proj.bias"
)
# attn2
if i in dual_attention_layers:
# Q, K, V
sample_q2, sample_k2, sample_v2 = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
)
sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
# qk norm
if has_qk_norm:
converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
)
# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn2.proj.bias"
)
# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
)
def get_attn2_layers(state_dict):
attn2_layers = []
for key in state_dict.keys():
if "attn2." in key:
# Extract the layer number from the key
layer_num = int(key.split(".")[1])
attn2_layers.append(layer_num)
return tuple(sorted(set(attn2_layers)))
def get_pos_embed_max_size(state_dict):
num_patches = state_dict["pos_embed"].shape[1]
pos_embed_max_size = int(num_patches**0.5)
return pos_embed_max_size
def get_caption_projection_dim(state_dict):
caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
return caption_projection_dim
def main(args):
original_ckpt = load_original_checkpoint(args.checkpoint_path)
original_dtype = next(iter(original_ckpt.values())).dtype
# Initialize dtype with a default value
dtype = None
if args.dtype is None:
dtype = original_dtype
elif args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
if dtype != original_dtype:
print(
f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
)
num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1  # noqa: C401
-caption_projection_dim = 1536
caption_projection_dim = get_caption_projection_dim(original_ckpt)
# () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
attn2_layers = get_attn2_layers(original_ckpt)
# sd3.5 uses qk norm ("rms_norm")
has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
# sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
pos_embed_max_size = get_pos_embed_max_size(original_ckpt)
converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
-original_ckpt, num_layers, caption_projection_dim
original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
)
with CTX():
transformer = SD3Transformer2DModel(
-sample_size=64,
sample_size=128,
patch_size=2,
in_channels=16,
joint_attention_dim=4096,
num_layers=num_layers,
caption_projection_dim=caption_projection_dim,
-num_attention_heads=24,
num_attention_heads=num_layers,
-pos_embed_max_size=192,
pos_embed_max_size=pos_embed_max_size,
qk_norm="rms_norm" if has_qk_norm else None,
dual_attention_layers=attn2_layers,
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_transformer_state_dict)
...
@@ -22,7 +22,7 @@ from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
processing of `context` conditions.
"""
-def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
use_dual_attention: bool = False,
):
super().__init__()
self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
-self.norm1 = AdaLayerNormZero(dim)
if use_dual_attention:
self.norm1 = SD35AdaLayerNormZeroX(dim)
else:
self.norm1 = AdaLayerNormZero(dim)
if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ class JointTransformerBlock(nn.Module):
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -134,8 +148,25 @@ class JointTransformerBlock(nn.Module):
context_pre_only=context_pre_only,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
)
if use_dual_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
)
else:
self.attn2 = None
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
@@ -159,7 +190,12 @@ class JointTransformerBlock(nn.Module):
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
-norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -177,6 +213,11 @@ class JointTransformerBlock(nn.Module):
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
...
@@ -193,7 +193,7 @@ class Attention(nn.Module):
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
else:
-raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
if cross_attention_norm is None:
self.norm_cross = None
@@ -250,6 +250,10 @@ class Attention(nn.Module):
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
)
else:
self.norm_added_q = None
self.norm_added_k = None
@@ -1050,61 +1054,72 @@ class JointAttnProcessor2_0:
) -> torch.FloatTensor:
residual = hidden_states
-input_ndim = hidden_states.ndim
batch_size = hidden_states.shape[0]
-if input_ndim == 4:
-batch_size, channel, height, width = hidden_states.shape
-hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-context_input_ndim = encoder_hidden_states.ndim
-if context_input_ndim == 4:
-batch_size, channel, height, width = encoder_hidden_states.shape
-encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
-# `context` projections.
-encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-# attention
-query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
-key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
-value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
-# Split the attention outputs.
-hidden_states, encoder_hidden_states = (
-hidden_states[:, : residual.shape[1]],
-hidden_states[:, residual.shape[1] :],
-)
if encoder_hidden_states is not None:
# Split the attention outputs.
hidden_states, encoder_hidden_states = (
hidden_states[:, : residual.shape[1]],
hidden_states[:, residual.shape[1] :],
)
if not attn.context_pre_only:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
-if not attn.context_pre_only:
-encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-if input_ndim == 4:
-hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-if context_input_ndim == 4:
-encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-return hidden_states, encoder_hidden_states
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class PAGJointAttnProcessor2_0:
...
@@ -97,6 +97,40 @@ class FP32LayerNorm(nn.LayerNorm):
).to(origin_dtype)
class SD35AdaLayerNormZeroX(nn.Module):
r"""
Adaptive layer norm zero (AdaLN-Zero) variant used by the SD3.5 dual-attention blocks; it produces an additional set of shift/scale/gate parameters for the second attention branch.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
norm_type (`str`, defaults to `"layer_norm"`): The normalization layer to use. Currently only `"layer_norm"` is supported.
bias (`bool`, defaults to `True`): Whether to use a bias in the linear projection.
"""
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
def forward(
self,
hidden_states: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, ...]:
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
9, dim=1
)
norm_hidden_states = self.norm(hidden_states)
hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
class AdaLayerNormZero(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
...
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -69,6 +69,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
dual_attention_layers: Tuple[
int, ...
] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
qk_norm: Optional[str] = None,
):
super().__init__()
default_out_channels = in_channels
@@ -97,6 +101,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=i == num_layers - 1,
qk_norm=qk_norm,
use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(self.config.num_layers)
]
...
@@ -73,6 +73,65 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
"joint_attention_dim": 32,
"pooled_projection_dim": 64,
"out_channels": 4,
"pos_embed_max_size": 96,
"dual_attention_layers": (),
"qk_norm": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
"num_layers": 2,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
"joint_attention_dim": 32,
"pooled_projection_dim": 64,
"out_channels": 4,
"pos_embed_max_size": 96,
"dual_attention_layers": (0,),
"qk_norm": "rms_norm",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
...