"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7bd50cabafc60bf45ebbe1957b125d3f4c758ba8"
Unverified Commit 3e71a206 authored by YiYi Xu, committed by GitHub

[refactor embeddings]pixart-alpha (#6212)



pixart-alpha
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent bf40d7d8
src/diffusers/models/embeddings.py

```diff
@@ -729,7 +729,7 @@ class PositionNet(nn.Module):
         return objs
 
 
-class CombinedTimestepSizeEmbeddings(nn.Module):
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     """
     For PixArt-Alpha.
 
@@ -746,45 +746,27 @@ class CombinedTimestepSizeEmbeddings(nn.Module):
         self.use_additional_conditions = use_additional_conditions
         if use_additional_conditions:
-            self.use_additional_conditions = True
             self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
             self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
             self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
 
-    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
-        if size.ndim == 1:
-            size = size[:, None]
-
-        if size.shape[0] != batch_size:
-            size = size.repeat(batch_size // size.shape[0], 1)
-            if size.shape[0] != batch_size:
-                raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
-
-        current_batch_size, dims = size.shape[0], size.shape[1]
-        size = size.reshape(-1)
-        size_freq = self.additional_condition_proj(size).to(size.dtype)
-
-        size_emb = embedder(size_freq)
-        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
-        return size_emb
-
     def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
 
         if self.use_additional_conditions:
-            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
-            aspect_ratio = self.apply_condition(
-                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
-            )
-            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
         else:
             conditioning = timesteps_emb
 
         return conditioning
 
 
-class CaptionProjection(nn.Module):
+class PixArtAlphaTextProjection(nn.Module):
     """
     Projects caption embeddings. Also handles dropout for classifier-free guidance.
 
@@ -796,9 +778,8 @@ class CaptionProjection(nn.Module):
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
         self.act_1 = nn.GELU(approximate="tanh")
         self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
-        self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
 
-    def forward(self, caption, force_drop_ids=None):
+    def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
```
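For orientation, here is a usage sketch of the renamed module that exercises the new `forward` path. It assumes a diffusers install that includes this refactor; `embedding_dim = 1152` (the PixArt-Alpha XL/2 inner dimension) and `batch_size = 2` are illustrative values, not API requirements:

```python
# Sketch only: exercises the simplified conditioning path above.
# Assumes a diffusers version containing this refactor.
import torch
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings

embedding_dim = 1152
emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=True
)

conditioning = emb(
    timestep=torch.tensor([999, 999]),                # (N,)
    resolution=torch.tensor([[1024.0, 1024.0]] * 2),  # (N, 2): height, width
    aspect_ratio=torch.tensor([[1.0]] * 2),           # (N, 1)
    batch_size=2,
    hidden_dtype=torch.float32,
)
# resolution contributes 2 * 384 dims and aspect_ratio 1 * 384, so the concat
# matches the (N, 1152) timestep embedding without the old apply_condition
# batch-repeat/reshape machinery.
print(conditioning.shape)  # torch.Size([2, 1152])
```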
src/diffusers/models/normalization.py

```diff
@@ -20,7 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .activations import get_activation
-from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
 
 class AdaLayerNorm(nn.Module):
@@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module):
     def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
         super().__init__()
 
-        self.emb = CombinedTimestepSizeEmbeddings(
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
             embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
         )
```
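The `size_emb_dim=embedding_dim // 3` argument is what makes the shapes in the new `forward` line up: `resolution` carries two scalars per sample and `aspect_ratio` one, each embedded to `size_emb_dim`, so the concatenation is `3 * size_emb_dim = embedding_dim` wide and can be added to the timestep embedding. A minimal shape check (plain arithmetic, no diffusers dependency; the 1152 is illustrative):

```python
# Shape bookkeeping behind size_emb_dim = embedding_dim // 3 in AdaLayerNormSingle.
embedding_dim = 1152
size_emb_dim = embedding_dim // 3   # 384

resolution_dims = 2                 # (height, width) per sample
aspect_ratio_dims = 1               # one scalar per sample

concat_width = (resolution_dims + aspect_ratio_dims) * size_emb_dim
assert concat_width == embedding_dim  # 1152: matches timesteps_emb, so the add is valid
```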
src/diffusers/models/transformer_2d.py

```diff
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
 from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
-from .embeddings import CaptionProjection, PatchEmbed
+from .embeddings import PatchEmbed, PixArtAlphaTextProjection
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle
@@ -235,7 +235,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.caption_projection = None
         if caption_channels is not None:
-            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
 
         self.gradient_checkpointing = False
```
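`PixArtAlphaTextProjection` keeps the two-layer tanh-GELU MLP but drops the unused `y_embedding` buffer and `force_drop_ids` argument, leaving a plain projection. A hedged usage sketch follows; it assumes a diffusers version with this refactor, and the 4096/1152/120 sizes (T5-XXL caption width, PixArt inner dim, token count) are only examples:

```python
# Sketch only: projects caption features the way Transformer2DModel does above.
# Assumes a post-refactor diffusers; tensor sizes are illustrative.
import torch
from diffusers.models.embeddings import PixArtAlphaTextProjection

proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=1152)
caption = torch.randn(2, 120, 4096)  # (batch, num_tokens, caption_channels)
hidden_states = proj(caption)        # linear -> tanh-GELU -> linear
print(hidden_states.shape)           # torch.Size([2, 120, 1152])
```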
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

```diff
@@ -853,6 +853,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
             resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
             aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+            if do_classifier_free_guidance:
+                resolution = torch.cat([resolution, resolution], dim=0)
+                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
             added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
 
         # 7. Denoising loop
```
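The new `do_classifier_free_guidance` branch exists because the pipeline runs the unconditional and conditional passes as one concatenated batch; without the doubling, `resolution` and `aspect_ratio` would have half the batch size of the latents. A minimal standalone sketch of the shape fix (hypothetical tensors, not pipeline code):

```python
# Illustrates why the pipeline doubles the added conditions under CFG.
import torch

batch_size, num_images_per_prompt = 1, 1
resolution = torch.tensor([[1024.0, 1024.0]]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([[1.0]]).repeat(batch_size * num_images_per_prompt, 1)

do_classifier_free_guidance = True  # guidance_scale > 1 in the pipeline
if do_classifier_free_guidance:
    # Latents are processed as cat([uncond, cond]), so every added condition
    # must be repeated to match the doubled leading dimension.
    resolution = torch.cat([resolution, resolution], dim=0)
    aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
print(resolution.shape, aspect_ratio.shape)  # torch.Size([2, 2]) torch.Size([2, 1])
```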