OpenDAS / diffusers · Commits

Commit d5acb411 (unverified) · authored Jul 19, 2022 by Patrick von Platen, committed by GitHub on Jul 19, 2022
Parent: 6cabc599

Finalize ldm (#96)

* upload
* make checkpoint work
* finalize

Showing 7 changed files with 999 additions and 135 deletions (+999, -135).
Changed files:

  src/diffusers/__init__.py                                                +9    -1
  src/diffusers/models/__init__.py                                         +1    -0
  src/diffusers/models/attention.py                                        +22   -24
  src/diffusers/models/unet_conditional.py (new file)                      +632  -0
  src/diffusers/models/unet_new.py                                         +264  -5
  src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py    +64   -99
  tests/test_modeling_utils.py                                             +7    -6
src/diffusers/__init__.py

@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"
 
 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import (
     DDIMPipeline,
     ...
src/diffusers/models/__init__.py

@@ -17,6 +17,7 @@
 # limitations under the License.
 
 from .unet import UNetModel
+from .unet_conditional import UNetConditionalModel
 from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
 from .unet_ldm import UNetLDMModel
 from .unet_sde_score_estimation import NCSNpp
 ...
src/diffusers/models/attention.py

@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
         self.value = nn.Linear(channels, channels)
 
         self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        self.proj_attn = nn.Linear(channels, channels, 1)
 
     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
         new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
+        self.n_heads = n_heads
+        self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
             ]
         )
 
-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in
 
+    def set_weight(self, layer):
+        self.norm = layer.norm
+        self.proj_in = layer.proj_in
+        self.transformer_blocks = layer.transformer_blocks
+        self.proj_out = layer.proj_out
+
 
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 
-# TODO(Patrick) - this can and should be removed
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
 
 
 # TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
@@ -298,17 +307,6 @@ def default(val, d):
     return d() if isfunction(d) else d
 
 
-# feedforward
-class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
-
-
 # the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """
@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = nn.Conv1d(channels, channels, 1)
 
         self.overwrite_qkv = overwrite_qkv
         self.overwrite_linear = overwrite_linear
@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):
             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
         else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+            self.proj_out = nn.Conv1d(channels, channels, 1)
 
         self.set_weights(self)
         self.is_overwritten = False
@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
         self.qkv.weight.data = qkv_weight
         self.qkv.bias.data = qkv_bias
 
-        proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+        proj_out = nn.Conv1d(self.channels, self.channels, 1)
         proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
         proj_out.bias.data = module.proj_out.bias.data
 ...
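The relocated feed-forward block is easiest to see in isolation. Below is a minimal, self-contained sketch (not part of the diff; shapes are illustrative) of the GEGLU gating that this commit moves above FeedForward: one linear layer produces a value half and a gate half, and the output is value * GELU(gate).

import torch
import torch.nn as nn
import torch.nn.functional as F


class GEGLU(nn.Module):
    """Gated-GELU feed-forward projection, mirroring the class moved in this commit."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # a single projection produces the value and the gate side by side
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)  # split into value and gate halves
        return x * F.gelu(gate)


hidden = torch.randn(2, 16, 320)   # (batch, tokens, dim) -- illustrative shapes
out = GEGLU(320, 1280)(hidden)     # FeedForward typically expands by 4x before projecting back
assert out.shape == (2, 16, 1280)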
src/diffusers/models/unet_conditional.py (new file, 0 → 100644)

import functools
import math
from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock, SpatialTransformer
from .embeddings import GaussianFourierProjection, get_timestep_embedding
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2DCrossAttn, get_down_block, get_up_block


class Combine(nn.Module):
    """Combine information from skip connections."""

    def __init__(self, dim1, dim2, method="cat"):
        super().__init__()
        # 1x1 convolution with DDPM initialization.
        self.Conv_0 = nn.Conv2d(dim1, dim2, kernel_size=1, padding=0)
        self.method = method

    # def forward(self, x, y):
    #     h = self.Conv_0(x)
    #     if self.method == "cat":
    #         return torch.cat([h, y], dim=1)
    #     elif self.method == "sum":
    #         return h + y
    #     else:
    #         raise ValueError(f"Method {self.method} not recognized.")


class TimestepEmbedding(nn.Module):
    def __init__(self, channel, time_embed_dim, act_fn="silu"):
        super().__init__()

        self.linear_1 = nn.Linear(channel, time_embed_dim)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)

    def forward(self, sample):
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class UNetConditionalModel(ModelMixin, ConfigMixin):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set,
        list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially increased efficiency.
    """

    def __init__(
        self,
        image_size=None,
        in_channels=4,
        out_channels=4,
        num_res_blocks=2,
        dropout=0,
        block_channels=(320, 640, 1280, 1280),
        down_blocks=(
            "UNetResCrossAttnDownBlock2D",
            "UNetResCrossAttnDownBlock2D",
            "UNetResCrossAttnDownBlock2D",
            "UNetResDownBlock2D",
        ),
        downsample_padding=1,
        up_blocks=(
            "UNetResUpBlock2D",
            "UNetResCrossAttnUpBlock2D",
            "UNetResCrossAttnUpBlock2D",
            "UNetResCrossAttnUpBlock2D",
        ),
        resnet_act_fn="silu",
        resnet_eps=1e-5,
        conv_resample=True,
        num_head_channels=8,
        flip_sin_to_cos=True,
        downscale_freq_shift=0,
        mid_block_scale_factor=1,
        center_input_sample=False,
        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        # LDM
        attention_resolutions=(4, 2, 1),
        # DDPM
        out_ch=None,
        resolution=None,
        attn_resolutions=None,
        resamp_with_conv=None,
        ch_mult=None,
        ch=None,
        ddpm=False,
        # SDE
        sde=False,
        nf=None,
        fir=None,
        progressive=None,
        progressive_combine=None,
        scale_by_sigma=None,
        skip_rescale=None,
        num_channels=None,
        centered=False,
        conditional=True,
        conv_size=3,
        fir_kernel=(1, 3, 3, 1),
        fourier_scale=16,
        init_scale=0.0,
        progressive_input="input_skip",
        resnet_num_groups=32,
        continuous=True,
        ldm=False,
    ):
        super().__init__()
        # register all __init__ params to be accessible via `self.config.<...>`
        # should probably be automated down the road as this is pure boilerplate code
        self.register_to_config(
            image_size=image_size, in_channels=in_channels, block_channels=block_channels,
            downsample_padding=downsample_padding, out_channels=out_channels, num_res_blocks=num_res_blocks,
            down_blocks=down_blocks, up_blocks=up_blocks, dropout=dropout, resnet_eps=resnet_eps,
            conv_resample=conv_resample, num_head_channels=num_head_channels, flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift, attention_resolutions=attention_resolutions,
            attn_resolutions=attn_resolutions, mid_block_scale_factor=mid_block_scale_factor,
            resnet_num_groups=resnet_num_groups, center_input_sample=center_input_sample,
        )

        self.ldm = ldm

        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        self.image_size = image_size
        time_embed_dim = block_channels[0] * 4
        # ======================================

        # input
        self.conv_in = nn.Conv2d(in_channels, block_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_steps = Timesteps(block_channels[0], flip_sin_to_cos, downscale_freq_shift)
        timestep_input_dim = block_channels[0]
        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.downsample_blocks = nn.ModuleList([])
        self.mid = None
        self.upsample_blocks = nn.ModuleList([])

        # down
        output_channel = block_channels[0]
        for i, down_block_type in enumerate(down_blocks):
            input_channel = output_channel
            output_channel = block_channels[i]
            is_final_block = i == len(block_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=num_res_blocks, in_channels=input_channel, out_channels=output_channel,
                temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=resnet_eps,
                resnet_act_fn=resnet_act_fn, attn_num_head_channels=num_head_channels,
                downsample_padding=downsample_padding,
            )
            self.downsample_blocks.append(down_block)

        # mid
        self.mid = UNetMidBlock2DCrossAttn(
            in_channels=block_channels[-1], dropout=dropout, temb_channels=time_embed_dim,
            resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default", attn_num_head_channels=num_head_channels,
            resnet_groups=resnet_num_groups,
        )

        # up
        reversed_block_channels = list(reversed(block_channels))
        output_channel = reversed_block_channels[0]
        for i, up_block_type in enumerate(up_blocks):
            prev_output_channel = output_channel
            output_channel = reversed_block_channels[i]
            input_channel = reversed_block_channels[min(i + 1, len(block_channels) - 1)]

            is_final_block = i == len(block_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=num_res_blocks + 1, in_channels=input_channel, out_channels=output_channel,
                prev_output_channel=prev_output_channel, temb_channels=time_embed_dim,
                add_upsample=not is_final_block, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn,
                attn_num_head_channels=num_head_channels,
            )
            self.upsample_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = resnet_num_groups if resnet_num_groups is not None else min(block_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_channels[0], num_groups=num_groups_out, eps=resnet_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_channels[0], out_channels, 3, padding=1)

        # ======================== Out ====================
        # =========== TO DELETE AFTER CONVERSION ==========
        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        self.is_overwritten = False
        if ldm:
            num_heads = 8
            num_head_channels = -1
            transformer_depth = 1
            use_spatial_transformer = True
            context_dim = 1280
            legacy = False
            model_channels = block_channels[0]
            channel_mult = tuple([x // model_channels for x in block_channels])
            self.init_for_ldm(
                in_channels, model_channels, channel_mult, num_res_blocks, dropout, time_embed_dim,
                attention_resolutions, num_head_channels, num_heads, legacy, False, transformer_depth,
                context_dim, conv_resample, out_channels,
            )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
    ) -> Dict[str, torch.FloatTensor]:
        # TODO(PVP) - to delete later at release
        # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
        # ======================================
        if not self.is_overwritten:
            self.set_weights()

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        t_emb = self.time_steps(timesteps)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.downsample_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # 5. up
        skip_sample = None
        for upsample_block in self.upsample_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        output = {"sample": sample}

        return output

    # !!!IMPORTANT - ALL OF THE FOLLOWING CODE WILL BE DELETED AT RELEASE TIME AND SHOULD NOT BE TAKEN INTO CONSIDERATION WHEN EVALUATING THE API ###
    # =================================================================================================================================================

    def set_weights(self):
        self.is_overwritten = True
        if self.ldm:
            self.time_embedding.linear_1.weight.data = self.time_embed[0].weight.data
            self.time_embedding.linear_1.bias.data = self.time_embed[0].bias.data
            self.time_embedding.linear_2.weight.data = self.time_embed[2].weight.data
            self.time_embedding.linear_2.bias.data = self.time_embed[2].bias.data

            self.conv_in.weight.data = self.input_blocks[0][0].weight.data
            self.conv_in.bias.data = self.input_blocks[0][0].bias.data

            # ================ SET WEIGHTS OF ALL WEIGHTS ==================
            for i, input_layer in enumerate(self.input_blocks[1:]):
                block_id = i // (self.config.num_res_blocks + 1)
                layer_in_block_id = i % (self.config.num_res_blocks + 1)

                if layer_in_block_id == 2:
                    self.downsample_blocks[block_id].downsamplers[0].conv.weight.data = input_layer[0].op.weight.data
                    self.downsample_blocks[block_id].downsamplers[0].conv.bias.data = input_layer[0].op.bias.data
                elif len(input_layer) > 1:
                    self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.downsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
                else:
                    self.downsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])

            self.mid.resnets[0].set_weight(self.middle_block[0])
            self.mid.resnets[1].set_weight(self.middle_block[2])
            self.mid.attentions[0].set_weight(self.middle_block[1])

            for i, input_layer in enumerate(self.output_blocks):
                block_id = i // (self.config.num_res_blocks + 1)
                layer_in_block_id = i % (self.config.num_res_blocks + 1)

                if len(input_layer) > 2:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
                    self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[2].conv.weight.data
                    self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[2].conv.bias.data
                elif len(input_layer) > 1 and "Upsample2D" in input_layer[1].__class__.__name__:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.upsample_blocks[block_id].upsamplers[0].conv.weight.data = input_layer[1].conv.weight.data
                    self.upsample_blocks[block_id].upsamplers[0].conv.bias.data = input_layer[1].conv.bias.data
                elif len(input_layer) > 1:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])
                    self.upsample_blocks[block_id].attentions[layer_in_block_id].set_weight(input_layer[1])
                else:
                    self.upsample_blocks[block_id].resnets[layer_in_block_id].set_weight(input_layer[0])

            self.conv_norm_out.weight.data = self.out[0].weight.data
            self.conv_norm_out.bias.data = self.out[0].bias.data
            self.conv_out.weight.data = self.out[2].weight.data
            self.conv_out.bias.data = self.out[2].bias.data

            self.remove_ldm()

    def init_for_ldm(
        self,
        in_channels,
        model_channels,
        channel_mult,
        num_res_blocks,
        dropout,
        time_embed_dim,
        attention_resolutions,
        num_head_channels,
        num_heads,
        legacy,
        use_spatial_transformer,
        transformer_depth,
        context_dim,
        conv_resample,
        out_channels,
    ):
        # TODO(PVP) - delete after weight conversion
        class TimestepEmbedSequential(nn.Sequential):
            """
            A sequential module that passes timestep embeddings to the children that support it as an extra input.
            """

            pass

        # TODO(PVP) - delete after weight conversion
        def conv_nd(dims, *args, **kwargs):
            """
            Create a 1D, 2D, or 3D convolution module.
            """
            if dims == 1:
                return nn.Conv1d(*args, **kwargs)
            elif dims == 2:
                return nn.Conv2d(*args, **kwargs)
            elif dims == 3:
                return nn.Conv3d(*args, **kwargs)
            raise ValueError(f"unsupported dimensions: {dims}")

        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        dims = 2
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResnetBlock2D(
                        in_channels=ch, out_channels=mult * model_channels, dropout=dropout,
                        temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", overwrite_for_ldm=True,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = num_head_channels
                    layers.append(
                        SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                        ),
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        Downsample2D(ch, use_conv=conv_resample, out_channels=out_ch, padding=1, name="op")
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels

        if legacy:
            # num_heads = 1
            dim_head = num_head_channels
        if dim_head < 0:
            dim_head = None

        # TODO(Patrick) - delete after weight conversion
        # init to be able to overwrite `self.mid`
        self.middle_block = TimestepEmbedSequential(
            ResnetBlock2D(
                in_channels=ch, out_channels=None, dropout=dropout, temb_channels=time_embed_dim,
                eps=1e-5, non_linearity="silu", overwrite_for_ldm=True,
            ),
            SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
            ),
            ResnetBlock2D(
                in_channels=ch, out_channels=None, dropout=dropout, temb_channels=time_embed_dim,
                eps=1e-5, non_linearity="silu", overwrite_for_ldm=True,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResnetBlock2D(
                        in_channels=ch + ich, out_channels=model_channels * mult, dropout=dropout,
                        temb_channels=time_embed_dim, eps=1e-5, non_linearity="silu", overwrite_for_ldm=True,
                    ),
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = num_head_channels
                    layers.append(
                        SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                        )
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(Upsample2D(ch, use_conv=conv_resample, out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            nn.GroupNorm(num_channels=model_channels, num_groups=32, eps=1e-5),
            nn.SiLU(),
            nn.Conv2d(model_channels, out_channels, 3, padding=1),
        )

    def remove_ldm(self):
        del self.time_embed
        del self.input_blocks
        del self.middle_block
        del self.output_blocks
        del self.out
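For orientation, a hedged usage sketch of the new model (not from the diff): it only exercises the forward signature above with random tensors. The latent resolution and the 77x1280 context shape are assumptions chosen to match the defaults in this file, not values read from a released checkpoint.

import torch
from diffusers import UNetConditionalModel  # exported via the __init__.py change above

# defaults: block_channels=(320, 640, 1280, 1280), cross-attention context dim 1280
unet = UNetConditionalModel(image_size=32, in_channels=4, out_channels=4)

latents = torch.randn(1, 4, 32, 32)        # (batch, in_channels, height, width)
timestep = torch.tensor([10])              # diffusion timestep
context = torch.randn(1, 77, 1280)         # e.g. LDMBert hidden states for cross-attention

out = unet(latents, timestep, encoder_hidden_states=context)
print(out["sample"].shape)                 # forward returns {"sample": predicted noise}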
src/diffusers/models/unet_new.py

@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from torch import nn
 
-from .attention import AttentionBlockNew
+from .attention import AttentionBlockNew, SpatialTransformer
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
@@ -56,6 +56,18 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif down_block_type == "UNetResCrossAttnDownBlock2D":
+        return UNetResCrossAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif down_block_type == "UNetResSkipDownBlock2D":
         return UNetResSkipDownBlock2D(
             num_layers=num_layers,
@@ -104,6 +116,18 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
         )
+    elif up_block_type == "UNetResCrossAttnUpBlock2D":
+        return UNetResCrossAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif up_block_type == "UNetResAttnUpBlock2D":
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,
@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
         return hidden_states
 
 
+class UNetMidBlock2DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        attention_type="default",
+        output_scale_factor=1.0,
+        cross_attention_dim=1280,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.attention_type = attention_type
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock(
+                in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels,
+                eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_layers):
+            attentions.append(
+                SpatialTransformer(
+                    in_channels, attn_num_head_channels, in_channels // attn_num_head_channels,
+                    depth=1, context_dim=cross_attention_dim,
+                )
+            )
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels,
+                    eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(hidden_states, encoder_hidden_states)
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
 class UNetResAttnDownBlock2D(nn.Module):
     def __init__(
         self,
@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
         return hidden_states, output_states
 
 
+class UNetResCrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_downsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels,
+                    eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels, attn_num_head_channels, out_channels // attn_num_head_channels,
+                    depth=1, context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels,
+                        padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
+
 class UNetResDownBlock2D(nn.Module):
     def __init__(
         self,
@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
         return hidden_states
 
 
+class UNetResCrossAttnUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels,
+                    temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels, attn_num_head_channels, out_channels // attn_num_head_channels,
+                    depth=1, context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,
@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
         self.act = None
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
         self.act = None
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
        for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
 ...
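A small sketch (shapes and argument values are illustrative assumptions) of the new cross-attention down block on its own: get_down_block dispatches on the block-type string, and the block returns both the transformed features and the per-layer residuals that the matching up block later consumes.

import torch
from diffusers.models.unet_new import get_down_block  # module path as used by unet_conditional.py above

block = get_down_block(
    "UNetResCrossAttnDownBlock2D",
    num_layers=2, in_channels=320, out_channels=320, temb_channels=1280,
    add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
    attn_num_head_channels=8, downsample_padding=1,
)

hidden = torch.randn(1, 320, 32, 32)   # feature map
temb = torch.randn(1, 1280)            # timestep embedding
context = torch.randn(1, 77, 1280)     # text-encoder states consumed by the SpatialTransformer layers

hidden, res_samples = block(hidden_states=hidden, temb=temb, encoder_hidden_states=context)
print(hidden.shape, len(res_samples))  # downsampled features plus one residual per layer (+1 for the downsampler)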
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

 from typing import Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 ...
@@ -15,6 +14,69 @@ from transformers.utils import logging
 from ...pipeline_utils import DiffusionPipeline
 
 
+class LatentDiffusionPipeline(DiffusionPipeline):
+    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
+        # eta corresponds to η in the paper and should be between [0, 1]
+
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+        self.vqvae.to(torch_device)
+        self.bert.to(torch_device)
+
+        # get unconditional embeddings for classifier-free guidance
+        if guidance_scale != 1.0:
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
+            uncond_embeddings = self.bert(uncond_input.input_ids)
+
+        # get text embedding
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+        text_embedding = self.bert(text_input.input_ids)
+
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            generator=generator,
+        ).to(torch_device)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm.tqdm(self.scheduler.timesteps):
+            # 1. predict noise residual
+            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+
+            if isinstance(pred_noise_t, dict):
+                pred_noise_t = pred_noise_t["sample"]
+
+            # 2. predict previous mean of image x_t-1 and add variance depending on eta
+            # do x_t -> x_t-1
+            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+
+        # scale and decode image with vae
+        image = 1 / 0.18215 * image
+        image = self.vqvae.decode(image)
+
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+
+        return image
+
+
 ################################################################################
 # Code for the text transformer model
 ################################################################################
 ...
@@ -542,100 +604,3 @@ class LDMBertModel(LDMBertPreTrainedModel):
         )
         sequence_output = outputs[0]
         return sequence_output
-
-
-class LatentDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
-        super().__init__()
-        scheduler = scheduler.set_format("pt")
-        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-
-    @torch.no_grad()
-    def __call__(
-        self,
-        prompt,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        guidance_scale=1.0,
-        num_inference_steps=50,
-    ):
-        # eta corresponds to η in the paper and should be between [0, 1]
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
-        self.bert.to(torch_device)
-
-        # get unconditional embeddings for classifier-free guidance
-        if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
-                torch_device
-            )
-            uncond_embeddings = self.bert(uncond_input.input_ids)
-
-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
-
-        num_trained_timesteps = self.scheduler.config.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            generator=generator,
-        ).to(torch_device)
-
-        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read the DDIM paper in detail for a full understanding.
-        # Notation (<variable name> -> <name in paper>)
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointing to x_t"
-        # - pred_prev_image -> "x_t-1"
-        for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # guidance_scale of 1 means no guidance
-            if guidance_scale == 1.0:
-                image_in = image
-                context = text_embedding
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-            else:
-                # for classifier-free guidance, we need to do two forward passes
-                # here we concatenate the conditional and unconditional embeddings into a single batch
-                # to avoid doing two forward passes
-                image_in = torch.cat([image] * 2)
-                context = torch.cat([uncond_embeddings, text_embedding])
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image_in, timesteps, context=context)
-
-            # perform guidance
-            if guidance_scale != 1.0:
-                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
-                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
-
-            # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
-
-            # 3. optionally sample variance
-            variance = 0
-            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-
-            # 4. set current image to prev_image: x_t -> x_t-1
-            image = pred_prev_image + variance
-
-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-
-        return image
\ No newline at end of file
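For orientation, a sketch of driving the reworked pipeline end to end. The checkpoint id comes from the updated tests below; the step count and eta are illustrative choices, and the output is the clamped image tensor returned by __call__ above.

import torch
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LatentDiffusionPipeline

ldm = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)

# __call__ now drives the denoising loop through scheduler.set_timesteps()/step()
# and decodes the final latents with the VQ-VAE
image = ldm([prompt], generator=generator, num_inference_steps=50, eta=0.3)
print(image.shape)  # tensor with values in [0, 1], shape (batch, channels, H, W)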
tests/test_modeling_utils.py

@@ -40,14 +40,17 @@ from diffusers import (
     ScoreSdeVeScheduler,
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
+    UNetConditionalModel,
     UNetLDMModel,
     UNetUnconditionalModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
+from transformers import BertTokenizer
 
 torch.backends.cuda.matmul.allow_tf32 = False
 ...
@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
 
-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
 
     @property
 ...
@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
     @slow
-    @unittest.skip("Skipping for now as it takes too long")
     def test_ldm_text2img(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
+        model_id = "fusing/latent-diffusion-text2im-large"
+        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
 ...
@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ldm_text2img_fast(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
+        model_id = "fusing/latent-diffusion-text2im-large"
+        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
 ...
@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_score_sde_ve_pipeline(self):
         model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+        model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
 ...