Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
3e71a206
Unverified
Commit
3e71a206
authored
Dec 19, 2023
by
YiYi Xu
Committed by
GitHub
Dec 19, 2023
Browse files
[refactor embeddings]pixart-alpha (#6212)
pixart-alpha Co-authored-by:
yiyixuxu
<
yixu310@gmail,com
>
parent
bf40d7d8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
31 deletions
+17
-31
src/diffusers/models/embeddings.py
src/diffusers/models/embeddings.py
+8
-27
src/diffusers/models/normalization.py
src/diffusers/models/normalization.py
+2
-2
src/diffusers/models/transformer_2d.py
src/diffusers/models/transformer_2d.py
+2
-2
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
...diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+5
-0
No files found.
src/diffusers/models/embeddings.py
View file @
3e71a206
...
...
@@ -729,7 +729,7 @@ class PositionNet(nn.Module):
return
objs
class
CombinedTimestepSizeEmbeddings
(
nn
.
Module
):
class
PixArtAlpha
CombinedTimestepSizeEmbeddings
(
nn
.
Module
):
"""
For PixArt-Alpha.
...
...
@@ -746,45 +746,27 @@ class CombinedTimestepSizeEmbeddings(nn.Module):
self
.
use_additional_conditions
=
use_additional_conditions
if
use_additional_conditions
:
self
.
use_additional_conditions
=
True
self
.
additional_condition_proj
=
Timesteps
(
num_channels
=
256
,
flip_sin_to_cos
=
True
,
downscale_freq_shift
=
0
)
self
.
resolution_embedder
=
TimestepEmbedding
(
in_channels
=
256
,
time_embed_dim
=
size_emb_dim
)
self
.
aspect_ratio_embedder
=
TimestepEmbedding
(
in_channels
=
256
,
time_embed_dim
=
size_emb_dim
)
def
apply_condition
(
self
,
size
:
torch
.
Tensor
,
batch_size
:
int
,
embedder
:
nn
.
Module
):
if
size
.
ndim
==
1
:
size
=
size
[:,
None
]
if
size
.
shape
[
0
]
!=
batch_size
:
size
=
size
.
repeat
(
batch_size
//
size
.
shape
[
0
],
1
)
if
size
.
shape
[
0
]
!=
batch_size
:
raise
ValueError
(
f
"`batch_size` should be
{
size
.
shape
[
0
]
}
but found
{
batch_size
}
."
)
current_batch_size
,
dims
=
size
.
shape
[
0
],
size
.
shape
[
1
]
size
=
size
.
reshape
(
-
1
)
size_freq
=
self
.
additional_condition_proj
(
size
).
to
(
size
.
dtype
)
size_emb
=
embedder
(
size_freq
)
size_emb
=
size_emb
.
reshape
(
current_batch_size
,
dims
*
self
.
outdim
)
return
size_emb
def
forward
(
self
,
timestep
,
resolution
,
aspect_ratio
,
batch_size
,
hidden_dtype
):
timesteps_proj
=
self
.
time_proj
(
timestep
)
timesteps_emb
=
self
.
timestep_embedder
(
timesteps_proj
.
to
(
dtype
=
hidden_dtype
))
# (N, D)
if
self
.
use_additional_conditions
:
resolution
=
self
.
a
pply
_condition
(
resolution
,
batch_size
=
batch_size
,
embedder
=
self
.
resolution_embedder
)
aspect_ratio
=
self
.
apply_condition
(
aspect_ratio
,
batch_size
=
batch_size
,
embedder
=
self
.
aspect_ratio_embedder
)
conditioning
=
timesteps_emb
+
torch
.
cat
([
resolution
,
aspect_ratio
],
dim
=
1
)
resolution
_emb
=
self
.
a
dditional
_condition
_proj
(
resolution
.
flatten
()).
to
(
hidden_dtype
)
resolution_emb
=
self
.
resolution_embedder
(
resolution_emb
).
reshape
(
batch_size
,
-
1
)
aspect_ratio
_emb
=
self
.
additional_condition_proj
(
aspect_ratio
.
flatten
()).
to
(
hidden_dtype
)
aspect_ratio_emb
=
self
.
aspect_ratio_embedder
(
aspect_ratio_emb
).
reshape
(
batch_size
,
-
1
)
conditioning
=
timesteps_emb
+
torch
.
cat
([
resolution
_emb
,
aspect_ratio
_emb
],
dim
=
1
)
else
:
conditioning
=
timesteps_emb
return
conditioning
class
Caption
Projection
(
nn
.
Module
):
class
PixArtAlphaText
Projection
(
nn
.
Module
):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
...
...
@@ -796,9 +778,8 @@ class CaptionProjection(nn.Module):
self
.
linear_1
=
nn
.
Linear
(
in_features
=
in_features
,
out_features
=
hidden_size
,
bias
=
True
)
self
.
act_1
=
nn
.
GELU
(
approximate
=
"tanh"
)
self
.
linear_2
=
nn
.
Linear
(
in_features
=
hidden_size
,
out_features
=
hidden_size
,
bias
=
True
)
self
.
register_buffer
(
"y_embedding"
,
nn
.
Parameter
(
torch
.
randn
(
num_tokens
,
in_features
)
/
in_features
**
0.5
))
def
forward
(
self
,
caption
,
force_drop_ids
=
None
):
def
forward
(
self
,
caption
):
hidden_states
=
self
.
linear_1
(
caption
)
hidden_states
=
self
.
act_1
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
...
...
src/diffusers/models/normalization.py
View file @
3e71a206
...
...
@@ -20,7 +20,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
.activations
import
get_activation
from
.embeddings
import
CombinedTimestepLabelEmbeddings
,
CombinedTimestepSizeEmbeddings
from
.embeddings
import
CombinedTimestepLabelEmbeddings
,
PixArtAlpha
CombinedTimestepSizeEmbeddings
class
AdaLayerNorm
(
nn
.
Module
):
...
...
@@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module):
def
__init__
(
self
,
embedding_dim
:
int
,
use_additional_conditions
:
bool
=
False
):
super
().
__init__
()
self
.
emb
=
CombinedTimestepSizeEmbeddings
(
self
.
emb
=
PixArtAlpha
CombinedTimestepSizeEmbeddings
(
embedding_dim
,
size_emb_dim
=
embedding_dim
//
3
,
use_additional_conditions
=
use_additional_conditions
)
...
...
src/diffusers/models/transformer_2d.py
View file @
3e71a206
...
...
@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from
..models.embeddings
import
ImagePositionalEmbeddings
from
..utils
import
USE_PEFT_BACKEND
,
BaseOutput
,
deprecate
,
is_torch_version
from
.attention
import
BasicTransformerBlock
from
.embeddings
import
CaptionProjection
,
PatchEmbed
from
.embeddings
import
PatchEmbed
,
PixArtAlphaTextProjection
from
.lora
import
LoRACompatibleConv
,
LoRACompatibleLinear
from
.modeling_utils
import
ModelMixin
from
.normalization
import
AdaLayerNormSingle
...
...
@@ -235,7 +235,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self
.
caption_projection
=
None
if
caption_channels
is
not
None
:
self
.
caption_projection
=
Caption
Projection
(
in_features
=
caption_channels
,
hidden_size
=
inner_dim
)
self
.
caption_projection
=
PixArtAlphaText
Projection
(
in_features
=
caption_channels
,
hidden_size
=
inner_dim
)
self
.
gradient_checkpointing
=
False
...
...
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
View file @
3e71a206
...
...
@@ -853,6 +853,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
aspect_ratio
=
torch
.
tensor
([
float
(
height
/
width
)]).
repeat
(
batch_size
*
num_images_per_prompt
,
1
)
resolution
=
resolution
.
to
(
dtype
=
prompt_embeds
.
dtype
,
device
=
device
)
aspect_ratio
=
aspect_ratio
.
to
(
dtype
=
prompt_embeds
.
dtype
,
device
=
device
)
if
do_classifier_free_guidance
:
resolution
=
torch
.
cat
([
resolution
,
resolution
],
dim
=
0
)
aspect_ratio
=
torch
.
cat
([
aspect_ratio
,
aspect_ratio
],
dim
=
0
)
added_cond_kwargs
=
{
"resolution"
:
resolution
,
"aspect_ratio"
:
aspect_ratio
}
# 7. Denoising loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment