"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "1a003c3fc21f476b5fb0aa2215426c914f229729"
Unverified commit a3904d7e, authored by XCL and committed by GitHub

[Tencent Hunyuan Team] Add HunyuanDiT-v1.2 Support (#8747)



* add v1.2 support

---------
Co-authored-by: xingchaoliu <xingchaoliu@tencent.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 7bfc1ee1
@@ -717,7 +717,14 @@ class HunyuanDiTAttentionPool(nn.Module):
 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
@@ -726,9 +733,15 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.style_embedder = nn.Embedding(1, embedding_dim)
-        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
@@ -743,16 +756,20 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
 
-        # extra condition2: image meta size embdding
-        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
-        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
-        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
-
-        # extra condition3: style embedding
-        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
-
-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embdding
+            image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)
+
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
 
         return conditioning
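The net effect of the new flag in this embedding class is purely dimensional: it decides which extra vectors are concatenated and therefore how wide the input to the `extra_embedder` MLP is. A minimal sketch of that arithmetic (the `embedding_dim` value below is a hypothetical placeholder, not taken from this commit):

```python
# Sketch under assumptions: `embedding_dim` is a hypothetical value chosen for
# illustration; `pooled_projection_dim` matches the constructor default above.
embedding_dim = 1408
pooled_projection_dim = 1024

# use_style_cond_and_image_meta_size=True (v1.0/v1.1 path):
# six 256-dim sinusoidal size embeddings + style embedding + pooled text.
extra_in_dim_v11 = 256 * 6 + embedding_dim + pooled_projection_dim  # 3968

# use_style_cond_and_image_meta_size=False (v1.2 path):
# only the pooled text projection feeds the extra embedder.
extra_in_dim_v12 = pooled_projection_dim  # 1024

print(extra_in_dim_v11, extra_in_dim_v12)
```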
@@ -249,6 +249,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             The length of the clip text embedding.
         text_len_t5 (`int`, *optional*):
             The length of the T5 text embedding.
+        use_style_cond_and_image_meta_size (`bool`, *optional*):
+            Whether or not to use style condition and image meta size. True for version <= 1.1, False for version >= 1.2.
     """
 
     @register_to_config
@@ -270,6 +272,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         pooled_projection_dim: int = 1024,
         text_len: int = 77,
         text_len_t5: int = 256,
+        use_style_cond_and_image_meta_size: bool = True,
     ):
         super().__init__()
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
@@ -301,6 +304,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             pooled_projection_dim=pooled_projection_dim,
             seq_len=text_len_t5,
             cross_attention_dim=cross_attention_dim_t5,
+            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
         )
 
         # HunyuanDiT Blocks
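Because the flag is captured by `@register_to_config`, a converted v1.2 checkpoint can carry it in its saved config and load through the usual API with no extra code. A hedged usage sketch (the repo id is an assumption for illustration, not part of this diff):

```python
import torch

from diffusers import HunyuanDiTPipeline

# The repo id below is an assumption for illustration; any HunyuanDiT v1.2
# checkpoint converted to the diffusers layout should behave the same way.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", torch_dtype=torch.float16
)

# For a v1.2 checkpoint the transformer config should disable the flag,
# so the style/size conditioning branch above is skipped at inference.
print(pipe.transformer.config.use_style_cond_and_image_meta_size)  # expected: False
```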