renzhc / diffusers_dcu / Commits / d5acb411

Unverified commit d5acb411, authored Jul 19, 2022 by Patrick von Platen, committed by GitHub on Jul 19, 2022.

Finalize ldm (#96)

* upload
* make checkpoint work
* finalize

Parent: 6cabc599
Showing 7 changed files with 999 additions and 135 deletions (+999 −135).
src/diffusers/__init__.py                                                 +9    −1
src/diffusers/models/__init__.py                                          +1    −0
src/diffusers/models/attention.py                                         +22   −24
src/diffusers/models/unet_conditional.py                                  +632  −0
src/diffusers/models/unet_new.py                                          +264  −5
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py     +64   −99
tests/test_modeling_utils.py                                              +7    −6
src/diffusers/__init__.py

@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import (
     DDIMPipeline,
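With the widened re-export, the new conditional UNet is importable from the package root. A minimal check, assuming this revision of the package is installed (the parameter-count print is only illustrative):

import torch
from diffusers import UNetConditionalModel

# default configuration as defined in src/diffusers/models/unet_conditional.py below
model = UNetConditionalModel()
print(sum(p.numel() for p in model.parameters()))  # rough size check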
src/diffusers/models/__init__.py

@@ -17,6 +17,7 @@
 # limitations under the License.
 from .unet import UNetModel
+from .unet_conditional import UNetConditionalModel
 from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
 from .unet_ldm import UNetLDMModel
 from .unet_sde_score_estimation import NCSNpp
src/diffusers/models/attention.py

@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
         self.value = nn.Linear(channels, channels)

         self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        self.proj_attn = nn.Linear(channels, channels, 1)

     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
         new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)

@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
+        self.n_heads = n_heads
+        self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
             ]
         )
-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention

@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in

+    def set_weight(self, layer):
+        self.norm = layer.norm
+        self.proj_in = layer.proj_in
+        self.transformer_blocks = layer.transformer_blocks
+        self.proj_out = layer.proj_out
+

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):

@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
         return self.net(x)

-# TODO(Patrick) - this can and should be removed
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)

 # TODO(Patrick) - remove once all weights have been converted -> not needed anymore then

@@ -298,17 +307,6 @@ def default(val, d):
     return d() if isfunction(d) else d

-# feedforward
-class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)

 # the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """

@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = nn.Conv1d(channels, channels, 1)

         self.overwrite_qkv = overwrite_qkv
         self.overwrite_linear = overwrite_linear

@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):
             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
         else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+            self.proj_out = nn.Conv1d(channels, channels, 1)
             self.set_weights(self)

         self.is_overwritten = False

@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
         self.qkv.weight.data = qkv_weight
         self.qkv.bias.data = qkv_bias

-        proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+        proj_out = nn.Conv1d(self.channels, self.channels, 1)
         proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
         proj_out.bias.data = module.proj_out.bias.data
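The feed-forward section above moves the GEGLU gating unit ahead of FeedForward and drops the zero_module initialization helper. As a quick illustration of the gating itself (a self-contained copy of the class from the diff; the dimensions are made up for the example):

import torch
import torch.nn as nn
import torch.nn.functional as F


class GEGLU(nn.Module):
    # project to twice the output width, then gate one half with GELU of the other
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


geglu = GEGLU(dim_in=64, dim_out=128)
out = geglu(torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 10, 128])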
src/diffusers/models/unet_conditional.py (new file, mode 100644)

import functools
import math
from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock, SpatialTransformer
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block


class Combine(nn.Module):
    """Combine information from skip connections."""

    def __init__(self, dim1, dim2, method="cat"):
        super().__init__()
        # 1x1 convolution with DDPM initialization.
        self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
        self.method = method

    # def forward(self, x, y):
    #     h = self.Conv_0(x)
    #     if self.method == "cat":
    #         return torch.cat([h, y], dim=1)
    #     elif self.method == "sum":
    #         return h + y
    #     else:
    #         raise ValueError(f"Method {self.method} not recognized.")


class TimestepEmbedding(nn.Module):
    def __init__(self, channel, time_embed_dim, act_fn="silu"):
        super().__init__()

        self.linear_1 = nn.Linear(channel, time_embed_dim)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample):
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class UNetConditionalModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set,
        list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially increased efficiency.
    """

    def __init__(
        self,
        image_size=None,
        in_channels=4,
        out_channels=4,
        num_res_blocks=2,
        dropout=0,
        block_channels=(320, 640, 1280, 1280),
        down_blocks=(
            "UNetResCrossAttnDownBlock2D",
            "UNetResCrossAttnDownBlock2D",
            "UNetResCrossAttnDownBlock2D",
            "UNetResDownBlock2D",
        ),
        downsample_padding=1,
        up_blocks=(
            "UNetResUpBlock2D",
            "UNetResCrossAttnUpBlock2D",
            "UNetResCrossAttnUpBlock2D",
            "UNetResCrossAttnUpBlock2D",
        ),
        resnet_act_fn="silu",
        resnet_eps=1e-5,
        conv_resample=True,
        num_head_channels=8,
        flip_sin_to_cos=True,
        downscale_freq_shift=0,
        mid_block_scale_factor=1,
        center_input_sample=False,
        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        # LDM
        attention_resolutions=(4, 2, 1),
        # DDPM
        out_ch=None,
        resolution=None,
        attn_resolutions=None,
        resamp_with_conv=None,
        ch_mult=None,
        ch=None,
        ddpm=False,
        # SDE
        sde=False,
        nf=None,
        fir=None,
        progressive=None,
        progressive_combine=None,
        scale_by_sigma=None,
        skip_rescale=None,
        num_channels=None,
        centered=False,
        conditional=True,
        conv_size=3,
        fir_kernel=(1, 3, 3, 1),
        fourier_scale=16,
        init_scale=0.0,
        progressive_input="input_skip",
        resnet_num_groups=32,
        continuous=True,
        ldm=False,
    ):
        super().__init__()
        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boiler plate code
        self.register_to_config(
            image_size=image_size,
            in_channels=in_channels,
            block_channels=block_channels,
            downsample_padding=downsample_padding,
            out_channels=out_channels,
            num_res_blocks=num_res_blocks,
            down_blocks=down_blocks,
            up_blocks=up_blocks,
            dropout=dropout,
            resnet_eps=resnet_eps,
            conv_resample=conv_resample,
            num_head_channels=num_head_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            attention_resolutions=attention_resolutions,
            attn_resolutions=attn_resolutions,
            mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups,
            center_input_sample=center_input_sample,
        )

        self.ldm = ldm

        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4
        # ======================================

        # input
        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
        timestep_input_dim = block_channels[0]
        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.downsample_blocks = nn.ModuleList([])
        self.mid = None
        self.upsample_blocks = nn.ModuleList([])

        # down
        output_channel = block_channels[0]
        for i, down_block_type in enumerate(down_blocks):
            input_channel = output_channel
            output_channel = block_channels[i]
            is_final_block = i == len(block_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=num_res_blocks,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
                downsample_padding=downsample_padding,
            )
            self.downsample_blocks.append(down_block)

        # mid
        self.mid = UNetMidBlock2DCrossAttn(
            in_channels=block_channels[-1],
            dropout=dropout,
            temb_channels=time_embed_dim,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attn_num_head_channels=num_head_channels,
            resnet_groups=resnet_num_groups,
        )

        # up
        reversed_block_channels = list(reversed(block_channels))
        output_channel = reversed_block_channels[0]
        for i, up_block_type in enumerate(up_blocks):
            prev_output_channel = output_channel
            output_channel = reversed_block_channels[i]
            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]

            is_final_block = i == len(block_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=num_res_blocks + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.upsample_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)

        # ======================== Out ====================

        # =========== TO DELETE AFTER CONVERSION ==========
        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        self.is_overwritten = False
        if ldm:
            num_heads = 8
            num_head_channels = -1
            transformer_depth = 1
            use_spatial_transformer = True
            context_dim = 1280
            legacy = False
            model_channels = block_channels[0]
            channel_mult = tuple([x // model_channels for x in block_channels])
            self.init_for_ldm(
                in_channels,
                model_channels,
                channel_mult,
                num_res_blocks,
                dropout,
                time_embed_dim,
                attention_resolutions,
                num_head_channels,
                num_heads,
                legacy,
                False,
                transformer_depth,
                context_dim,
                conv_resample,
                out_channels,
            )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
    ) -> Dict[str, torch.FloatTensor]:
        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        if not self.is_overwritten:
            self.set_weights()

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        t_emb = self.time_steps(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # 5. up
        skip_sample = None
        for upsample_block in self.upsample_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        output = {"sample": sample}

        return output

    # !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
    # =================================================================================================================================================

    def set_weights(self):
        self.is_overwritten = True
        if self.ldm:
            self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
            self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
            self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
            self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data

            self.conv_in.weight.data = self.input_blocks[0][0].weight.data
            self.conv_in.bias.data = self.input_blocks[0][0].bias.data

            # ================ SET WEIGHTS OF ALL WEIGHTS ==================
            for i, input_layer in enumerate(self.input_blocks[1:]):
                block_id = i // (self.config.num_res_blocks + 1)
                layer_in_block_id = i % (self.config.num_res_blocks + 1)

                if layer_in_block_id == 2:
                    self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
                    self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
                elif len(input_layer) > 1:
                    self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
                else:
                    self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])

            self.mid.resnets[0].set_weight(self.middle_block[0])
            self.mid.resnets[1].set_weight(self.middle_block[2])
            self.mid.attentions[0].set_weight(self.middle_block[1])

            for i, input_layer in enumerate(self.output_blocks):
                block_id = i // (self.config.num_res_blocks + 1)
                layer_in_block_id = i % (self.config.num_res_blocks + 1)

                if len(input_layer) > 2:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
                    self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
                    self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
                elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
                    self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
                elif len(input_layer) > 1:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
                else:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])

            self.conv_norm_out.weight.data = self.out[0].weight.data
            self.conv_norm_out.bias.data = self.out[0].bias.data
            self.conv_out.weight.data = self.out[2].weight.data
            self.conv_out.bias.data = self.out[2].bias.data

            self.remove_ldm()

    def init_for_ldm(
        self,
        in_channels,
        model_channels,
        channel_mult,
        num_res_blocks,
        dropout,
        time_embed_dim,
        attention_resolutions,
        num_head_channels,
        num_heads,
        legacy,
        use_spatial_transformer,
        transformer_depth,
        context_dim,
        conv_resample,
        out_channels,
    ):
        # TODO(PVP) - delete after weight conversion
        class TimestepEmbedSequential(nn.Sequential):
            """
            A sequential module that passes timestep embeddings to the children that support it as an extra input.
            """

            pass

        # TODO(PVP) - delete after weight conversion
        def conv_nd(dims, *args, **kwargs):
            """
            Create a 1D, 2D, or 3D convolution module.
            """
            if dims == 1:
                return nn.Conv1d(*args, **kwargs)
            elif dims == 2:
                return nn.Conv2d(*args, **kwargs)
            elif dims == 3:
                return nn.Conv3d(*args, **kwargs)
            raise ValueError(f"unsupported dimensions: {dims}")

        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        dims = 2
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResnetBlock2D(
                        in_channels=ch,
                        out_channels=mult * model_channels,
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = num_head_channels
                    layers.append(
                        SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                        ),
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = num_head_channels

        if dim_head < 0:
            dim_head = None

        # TODO(Patrick) - delete after weight conversion
        # init to be able to overwrite `self.mid`
        self.middle_block = TimestepEmbedSequential(
            ResnetBlock2D(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
            SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
            ),
            ResnetBlock2D(
                in_channels=ch,
                out_channels=None,
                dropout=dropout,
                temb_channels=time_embed_dim,
                eps=1e-5,
                non_linearity="silu",
                overwrite_for_ldm=True,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResnetBlock2D(
                        in_channels=ch + ich,
                        out_channels=model_channels * mult,
                        dropout=dropout,
                        temb_channels=time_embed_dim,
                        eps=1e-5,
                        non_linearity="silu",
                        overwrite_for_ldm=True,
                    ),
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = num_head_channels
                    layers.append(
                        SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
            nn.SiLU(),
            nn.Conv2d(model_channels, out_channels, 3, padding=1),
        )

    def remove_ldm(self):
        del self.time_embed
        del self.input_blocks
        del self.middle_block
        del self.output_blocks
        del self.out
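A short smoke-test sketch for the new model, assuming this revision is importable; the latent and context shapes (4x64x64 latents, 77 tokens of width 1280) are illustrative values chosen to match the defaults above, not values taken from the commit:

import torch
from diffusers import UNetConditionalModel

model = UNetConditionalModel(image_size=64, in_channels=4, out_channels=4)
model.eval()

sample = torch.randn(1, 4, 64, 64)               # noisy latents
timestep = torch.tensor([10], dtype=torch.long)  # diffusion timestep
context = torch.randn(1, 77, 1280)               # text-encoder hidden states (cross-attention context)

with torch.no_grad():
    out = model(sample, timestep, encoder_hidden_states=context)
print(out["sample"].shape)  # expected to mirror the input latent shape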
src/diffusers/models/unet_new.py

@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from torch import nn

-from .attention import AttentionBlockNew
+from .attention import AttentionBlockNew, SpatialTransformer
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D

@@ -56,6 +56,18 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif down_block_type == "UNetResCrossAttnDownBlock2D":
+        return UNetResCrossAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif down_block_type == "UNetResSkipDownBlock2D":
         return UNetResSkipDownBlock2D(
             num_layers=num_layers,

@@ -104,6 +116,18 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
         )
+    elif up_block_type == "UNetResCrossAttnUpBlock2D":
+        return UNetResCrossAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif up_block_type == "UNetResAttnUpBlock2D":
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,

@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
         return hidden_states

+class UNetMidBlock2DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        attention_type="default",
+        output_scale_factor=1.0,
+        cross_attention_dim=1280,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.attention_type = attention_type
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock(
+                in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels,
+                eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_layers):
+            attentions.append(
+                SpatialTransformer(
+                    in_channels, attn_num_head_channels, in_channels // attn_num_head_channels,
+                    depth=1, context_dim=cross_attention_dim,
+                )
+            )
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels,
+                    eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(hidden_states, encoder_hidden_states)
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
 class UNetResAttnDownBlock2D(nn.Module):
     def __init__(
         self,

@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
         return hidden_states, output_states

+class UNetResCrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_downsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels,
+                    eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels, attn_num_head_channels, out_channels // attn_num_head_channels,
+                    depth=1, context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels,
+                        padding=downsample_padding, name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
 class UNetResDownBlock2D(nn.Module):
     def __init__(
         self,

@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
         return hidden_states

+class UNetResCrossAttnUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels,
+                    temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels, attn_num_head_channels, out_channels // attn_num_head_channels,
+                    depth=1, context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
 class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,

@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
         self.act = None

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]

@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
         self.act = None

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
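The new factory branches and cross-attention blocks can be exercised in isolation. A sketch under the assumption that this module path is importable at this revision and that ResnetBlock/SpatialTransformer behave as shown above; the channel, timestep-embedding, and context sizes are illustrative:

import torch
from diffusers.models.unet_new import get_down_block

block = get_down_block(
    "UNetResCrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    attn_num_head_channels=8,
    downsample_padding=1,
)

hidden = torch.randn(1, 320, 32, 32)   # feature map
temb = torch.randn(1, 1280)            # timestep embedding
context = torch.randn(1, 77, 1280)     # cross-attention context
hidden, skip_states = block(hidden, temb=temb, encoder_hidden_states=context)
print(hidden.shape, len(skip_states))  # downsampled features plus one skip state per layer (+1 for the downsampler)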
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

 from typing import Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint

@@ -15,6 +14,69 @@ from transformers.utils import logging
 from ...pipeline_utils import DiffusionPipeline

+class LatentDiffusionPipeline(DiffusionPipeline):
+    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
+        # eta corresponds to η in paper and should be between [0, 1]
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+        self.vqvae.to(torch_device)
+        self.bert.to(torch_device)
+
+        # get unconditional embeddings for classifier free guidence
+        if guidance_scale != 1.0:
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+            uncond_embeddings = self.bert(uncond_input.input_ids)
+
+        # get text embedding
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+        text_embedding = self.bert(text_input.input_ids)
+
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            generator=generator,
+        ).to(torch_device)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm.tqdm(self.scheduler.timesteps):
+            # 1. predict noise residual
+            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+
+            if isinstance(pred_noise_t, dict):
+                pred_noise_t = pred_noise_t["sample"]
+
+            # 2. predict previous mean of image x_t-1 and add variance depending on eta
+            # do x_t -> x_t-1
+            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+
+        # scale and decode image with vae
+        image = 1 / 0.18215 * image
+        image = self.vqvae.decode(image)
+
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+
+        return image
+
 ################################################################################
 # Code for the text transformer model
 ################################################################################

@@ -541,101 +603,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
             return_dict=return_dict,
         )
         sequence_output = outputs[0]
         return sequence_output
\ No newline at end of file

The previous copy of LatentDiffusionPipeline at the bottom of this file is deleted in this commit. Its removed body was:

class LatentDiffusionPipeline(DiffusionPipeline):
    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        batch_size=1,
        generator=None,
        torch_device=None,
        eta=0.0,
        guidance_scale=1.0,
        num_inference_steps=50,
    ):
        # eta corresponds to η in paper and should be between [0, 1]
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        self.vqvae.to(torch_device)
        self.bert.to(torch_device)

        # get unconditional embeddings for classifier free guidence
        if guidance_scale != 1.0:
            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
            uncond_embeddings = self.bert(uncond_input.input_ids)

        # get text embedding
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
        text_embedding = self.bert(text_input.input_ids)

        num_trained_timesteps = self.scheduler.config.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            generator=generator,
        ).to(torch_device)

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding
        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointingc to x_t"
        # - pred_prev_image -> "x_t-1"
        for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # guidance_scale of 1 means no guidance
            if guidance_scale == 1.0:
                image_in = image
                context = text_embedding
                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
            else:
                # for classifier free guidance, we need to do two forward passes
                # here we concanate embedding and unconditioned embedding in a single batch
                # to avoid doing two forward passes
                image_in = torch.cat([image] * 2)
                context = torch.cat([uncond_embeddings, text_embedding])
                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)

            # 1. predict noise residual
            pred_noise_t = self.unet(image_in, timesteps, context=context)

            # perform guidance
            if guidance_scale != 1.0:
                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)

            # 3. optionally sample variance
            variance = 0
            if eta > 0:
                noise = torch.randn(image.shape, generator=generator).to(image.device)
                variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise

            # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

        # scale and decode image with vae
        image = 1 / 0.18215 * image
        image = self.vqvae.decode(image)

        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

        return image
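For reference, a usage sketch of the rewritten pipeline; it assumes the checkpoint id used in the updated tests below is reachable on the Hub and that the bundled scheduler supports set_timesteps/step as called above:

import torch
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LatentDiffusionPipeline

ldm = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.manual_seed(0)
# guidance_scale=1.0 keeps classifier-free guidance off, matching the single-pass branch above
image = ldm(
    "A painting of a squirrel eating a burger",
    generator=generator,
    guidance_scale=1.0,
    num_inference_steps=20,
)
print(image.shape)  # decoded images clamped to [0, 1]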
tests/test_modeling_utils.py

@@ -40,14 +40,17 @@ from diffusers import (
     ScoreSdeVeScheduler,
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
+    UNetConditionalModel,
     UNetLDMModel,
     UNetUnconditionalModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
+from transformers import BertTokenizer

 torch.backends.cuda.matmul.allow_tf32 = False

@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))

-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL

     @property

@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow
-    @unittest.skip("Skipping for now as it takes too long")
     def test_ldm_text2img(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
+        model_id = "fusing/latent-diffusion-text2im-large"
+        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)

@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ldm_text2img_fast(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
+        model_id = "fusing/latent-diffusion-text2im-large"
+        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)

@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_score_sde_ve_pipeline(self):
         model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+        model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
         torch.manual_seed(0)
         if torch.cuda.is_available():