renzhc / diffusers_dcu / Commits / 3e71a206

Unverified commit 3e71a206, authored Dec 19, 2023 by YiYi Xu, committed by GitHub on Dec 19, 2023

[refactor embeddings]pixart-alpha (#6212)

pixart-alpha

Co-authored-by: yiyixuxu <yixu310@gmail,com>

parent bf40d7d8
Showing 4 changed files with 17 additions and 31 deletions.
src/diffusers/models/embeddings.py                              +8 -27
src/diffusers/models/normalization.py                           +2 -2
src/diffusers/models/transformer_2d.py                          +2 -2
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py   +5 -0
src/diffusers/models/embeddings.py

@@ -729,7 +729,7 @@ class PositionNet(nn.Module):
         return objs


-class CombinedTimestepSizeEmbeddings(nn.Module):
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     """
     For PixArt-Alpha.

@@ -746,45 +746,27 @@ class CombinedTimestepSizeEmbeddings(nn.Module):
         self.use_additional_conditions = use_additional_conditions
         if use_additional_conditions:
-            self.use_additional_conditions = True
             self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
             self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
             self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

-    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
-        if size.ndim == 1:
-            size = size[:, None]
-
-        if size.shape[0] != batch_size:
-            size = size.repeat(batch_size // size.shape[0], 1)
-            if size.shape[0] != batch_size:
-                raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
-
-        current_batch_size, dims = size.shape[0], size.shape[1]
-        size = size.reshape(-1)
-        size_freq = self.additional_condition_proj(size).to(size.dtype)
-
-        size_emb = embedder(size_freq)
-        size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
-        return size_emb
-
     def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

         if self.use_additional_conditions:
-            resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
-            aspect_ratio = self.apply_condition(
-                aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
-            )
-            conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
+            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
         else:
             conditioning = timesteps_emb

         return conditioning

...

-class CaptionProjection(nn.Module):
+class PixArtAlphaTextProjection(nn.Module):
     """
     Projects caption embeddings. Also handles dropout for classifier-free guidance.

@@ -796,9 +778,8 @@ class CaptionProjection(nn.Module):
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
         self.act_1 = nn.GELU(approximate="tanh")
         self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
-        self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))

-    def forward(self, caption, force_drop_ids=None):
+    def forward(self, caption):
         hidden_states = self.linear_1(caption)
         hidden_states = self.act_1(hidden_states)
         hidden_states = self.linear_2(hidden_states)
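For orientation, a minimal sketch of how the renamed embedding class can be exercised after this refactor. The 1152/384 dimensions, the 1024x1024 resolution, and the import path are assumptions for illustration (chosen so that embedding_dim equals 3 * size_emb_dim, which the addition in forward requires), not values fixed by this diff.

```python
import torch

from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings

# embedding_dim must equal 3 * size_emb_dim: the resolution embedding
# contributes 2 * size_emb_dim features (height, width) and the aspect-ratio
# embedding contributes size_emb_dim, so their concatenation matches the
# timestep embedding width.
embedding_dim, size_emb_dim, batch_size = 1152, 384, 2

emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
    embedding_dim, size_emb_dim=size_emb_dim, use_additional_conditions=True
)

timestep = torch.randint(0, 1000, (batch_size,))
resolution = torch.tensor([[1024.0, 1024.0]]).repeat(batch_size, 1)  # (N, 2): height, width
aspect_ratio = torch.tensor([[1.0]]).repeat(batch_size, 1)           # (N, 1)

conditioning = emb(
    timestep, resolution, aspect_ratio, batch_size=batch_size, hidden_dtype=torch.float32
)
print(conditioning.shape)  # torch.Size([2, 1152])
```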
src/diffusers/models/normalization.py

@@ -20,7 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .activations import get_activation
-from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings


 class AdaLayerNorm(nn.Module):

@@ -91,7 +91,7 @@ class AdaLayerNormSingle(nn.Module):
     def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
         super().__init__()

-        self.emb = CombinedTimestepSizeEmbeddings(
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
             embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
         )
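A quick sketch of the downstream effect in normalization.py: AdaLayerNormSingle now builds the renamed embedding class. The embedding_dim value below is only an example.

```python
from diffusers.models.normalization import AdaLayerNormSingle

# AdaLayerNormSingle passes embedding_dim // 3 as size_emb_dim, so pick a
# dimension divisible by 3 (1152 is an illustrative value only).
ada_norm = AdaLayerNormSingle(embedding_dim=1152, use_additional_conditions=True)
print(type(ada_norm.emb).__name__)  # PixArtAlphaCombinedTimestepSizeEmbeddings
```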
src/diffusers/models/transformer_2d.py

@@ -22,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
 from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
-from .embeddings import CaptionProjection, PatchEmbed
+from .embeddings import PatchEmbed, PixArtAlphaTextProjection
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
 from .normalization import AdaLayerNormSingle

@@ -235,7 +235,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.caption_projection = None
         if caption_channels is not None:
-            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

         self.gradient_checkpointing = False
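To illustrate the renamed caption projection that Transformer2DModel now instantiates, a small sketch; the 4096 and 1152 sizes are assumptions in the spirit of typical PixArt-Alpha configs, not values fixed by this diff.

```python
import torch

from diffusers.models.embeddings import PixArtAlphaTextProjection

# Projects text-encoder caption embeddings (last dimension = caption_channels)
# to the transformer's inner dimension; nn.Linear acts on the last axis, so
# the (batch, tokens) leading shape is preserved.
proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=1152)
caption = torch.randn(2, 120, 4096)  # (batch, num_tokens, caption_channels)
hidden_states = proj(caption)
print(hidden_states.shape)  # torch.Size([2, 120, 1152])
```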
src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

@@ -853,6 +853,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
             resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
             aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+
+            if do_classifier_free_guidance:
+                resolution = torch.cat([resolution, resolution], dim=0)
+                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
             added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

         # 7. Denoising loop
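The five added lines keep the PixArt-Alpha micro-conditions aligned with the doubled latent batch used under classifier-free guidance (unconditional plus conditional halves). A toy illustration of the resulting shapes, using placeholder values:

```python
import torch

batch_size, num_images_per_prompt = 1, 1
height, width = 1024, 1024
do_classifier_free_guidance = True

resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)

if do_classifier_free_guidance:
    # The latent batch is concatenated as [uncond, cond], so the per-sample
    # conditions are duplicated the same way to keep batch sizes matched.
    resolution = torch.cat([resolution, resolution], dim=0)
    aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

print(resolution.shape, aspect_ratio.shape)  # torch.Size([2, 2]) torch.Size([2, 1])
```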