OpenDAS / diffusers / Commits / d5acb411

Unverified commit d5acb411, authored Jul 19, 2022 by Patrick von Platen, committed by GitHub on Jul 19, 2022.
Finalize ldm (#96)
* upload
* make checkpoint work
* finalize
parent 6cabc599
Showing 7 changed files with 999 additions and 135 deletions (+999, −135).
src/diffusers/__init__.py                                                 +9    −1
src/diffusers/models/__init__.py                                          +1    −0
src/diffusers/models/attention.py                                         +22   −24
src/diffusers/models/unet_conditional.py                                  +632  −0
src/diffusers/models/unet_new.py                                          +264  −5
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py     +64   −99
tests/test_modeling_utils.py                                              +7    −6
src/diffusers/__init__.py  (view file @ d5acb411)
@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import (
     DDIMPipeline,
     ...
src/diffusers/models/__init__.py  (view file @ d5acb411)
@@ -17,6 +17,7 @@
 # limitations under the License.
 from .unet import UNetModel
+from .unet_conditional import UNetConditionalModel
 from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
 from .unet_ldm import UNetLDMModel
 from .unet_sde_score_estimation import NCSNpp
 ...
src/diffusers/models/attention.py  (view file @ d5acb411)
@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
         self.value = nn.Linear(channels, channels)

         self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        self.proj_attn = nn.Linear(channels, channels, 1)

     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
         new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
         ...
@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
+        self.n_heads = n_heads
+        self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
         ...
@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
             ]
         )

-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
         ...
@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in

+    def set_weight(self, layer):
+        self.norm = layer.norm
+        self.proj_in = layer.proj_in
+        self.transformer_blocks = layer.transformer_blocks
+        self.proj_out = layer.proj_out
+

 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
         ...
@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
         return self.net(x)


 # TODO(Patrick) - this can and should be removed
 def zero_module(module):
     """
     Zero out the parameters of a module and return it.
     """
     for p in module.parameters():
         p.detach().zero_()
     return module


 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)

     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
         return x * F.gelu(gate)


 # TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
 ...
@@ -298,17 +307,6 @@ def default(val, d):
     return d() if isfunction(d) else d


-# feedforward
-class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
-

 # the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """
     ...
@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = nn.Conv1d(channels, channels, 1)

         self.overwrite_qkv = overwrite_qkv
         self.overwrite_linear = overwrite_linear
         ...
@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):
             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
         else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+            self.proj_out = nn.Conv1d(channels, channels, 1)
             self.set_weights(self)

         self.is_overwritten = False
         ...
@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
         self.qkv.weight.data = qkv_weight
         self.qkv.bias.data = qkv_bias

-        proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+        proj_out = nn.Conv1d(self.channels, self.channels, 1)
         proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
         proj_out.bias.data = module.proj_out.bias.data
         ...
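For context on the repeated one-line changes in this file (not part of the diff itself): dropping the zero_module(...) wrapper only changes how the affected projection layers are initialized, since zero_module zeroes a layer's parameters in place. A minimal sketch, reusing the zero_module definition shown in this diff:

import torch.nn as nn

def zero_module(module):
    # zero out all parameters of a module in place and return it (as defined in attention.py)
    for p in module.parameters():
        p.detach().zero_()
    return module

proj_old = zero_module(nn.Linear(64, 64))  # the old wrapped projection: starts as all zeros
proj_new = nn.Linear(64, 64)               # the new unwrapped projection: default PyTorch init

print(proj_old.weight.abs().sum().item())  # 0.0
print(proj_new.weight.abs().sum().item())  # some positive value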
src/diffusers/models/unet_conditional.py  (new file, 0 → 100644, view file @ d5acb411)

This diff is collapsed.
src/diffusers/models/unet_new.py  (view file @ d5acb411)
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from torch import nn

-from .attention import AttentionBlockNew
+from .attention import AttentionBlockNew, SpatialTransformer
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
 ...
@@ -56,6 +56,18 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif down_block_type == "UNetResCrossAttnDownBlock2D":
+        return UNetResCrossAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif down_block_type == "UNetResSkipDownBlock2D":
         return UNetResSkipDownBlock2D(
             num_layers=num_layers,
             ...
@@ -104,6 +116,18 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
         )
+    elif up_block_type == "UNetResCrossAttnUpBlock2D":
+        return UNetResCrossAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif up_block_type == "UNetResAttnUpBlock2D":
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,
             ...
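A short sketch (not part of this commit) of how the new block-type strings registered above would be resolved through the factory helper. It assumes the block type string is the helper's first positional argument and that the keyword arguments match the ones passed through above; all values shown are illustrative:

from diffusers.models.unet_new import get_down_block  # assumed import path

down_block = get_down_block(
    "UNetResCrossAttnDownBlock2D",  # assumed to be the first positional argument
    num_layers=2,
    in_channels=64,
    out_channels=128,
    temb_channels=256,
    add_downsample=True,
    resnet_eps=1e-6,
    resnet_act_fn="swish",
    downsample_padding=1,
    attn_num_head_channels=8,
)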
@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
         return hidden_states


+class UNetMidBlock2DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        attention_type="default",
+        output_scale_factor=1.0,
+        cross_attention_dim=1280,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.attention_type = attention_type
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_layers):
+            attentions.append(
+                SpatialTransformer(
+                    in_channels,
+                    attn_num_head_channels,
+                    in_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(hidden_states, encoder_hidden_states)
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
 class UNetResAttnDownBlock2D(nn.Module):
     def __init__(
         self,
         ...
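A minimal sketch (not part of this commit) of how the new UNetMidBlock2DCrossAttn consumes a latent feature map, a timestep embedding, and text-encoder hidden states. The import path and all tensor shapes are illustrative assumptions:

import torch
from diffusers.models.unet_new import UNetMidBlock2DCrossAttn  # assumed import path

block = UNetMidBlock2DCrossAttn(
    in_channels=64,             # feature channels of the mid block
    temb_channels=128,          # width of the timestep embedding expected by the ResnetBlocks
    attn_num_head_channels=8,   # passed to SpatialTransformer as n_heads; d_head = 64 // 8
    cross_attention_dim=32,     # width of the text/context embeddings
)

hidden_states = torch.randn(1, 64, 16, 16)        # latent feature map
temb = torch.randn(1, 128)                        # timestep embedding
encoder_hidden_states = torch.randn(1, 77, 32)    # e.g. a text-encoder sequence output

out = block(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # expected torch.Size([1, 64, 16, 16]); channels and spatial size are preserved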
@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
         return hidden_states, output_states


+class UNetResCrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_downsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels,
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
+
 class UNetResDownBlock2D(nn.Module):
     def __init__(
         self,
         ...
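A similar sketch (again not from the commit) for the new cross-attention down block, showing the tuple of skip states it returns. Shapes and the import path are illustrative assumptions:

import torch
from diffusers.models.unet_new import UNetResCrossAttnDownBlock2D  # assumed import path

block = UNetResCrossAttnDownBlock2D(
    in_channels=32,
    out_channels=64,
    temb_channels=128,
    num_layers=2,
    attn_num_head_channels=8,   # passed to SpatialTransformer as the number of heads
    cross_attention_dim=40,     # width of the text-encoder hidden states
)

x = torch.randn(1, 32, 32, 32)       # input feature map
temb = torch.randn(1, 128)           # timestep embedding
context = torch.randn(1, 77, 40)     # text-encoder hidden states

hidden_states, skips = block(x, temb=temb, encoder_hidden_states=context)
# `skips` holds one tensor per resnet/attention pair plus one more after downsampling;
# the up blocks later pop them from the end of the tuple.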
@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
         return hidden_states


+class UNetResCrossAttnUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels,
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,
         ...
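And a matching sketch (not from the commit) of the up block consuming skip connections, which it pops from the end of the tuple; shapes and the import path are illustrative assumptions:

import torch
from diffusers.models.unet_new import UNetResCrossAttnUpBlock2D  # assumed import path

block = UNetResCrossAttnUpBlock2D(
    in_channels=32,            # channels of the outermost skip consumed by the last resnet
    out_channels=64,
    prev_output_channel=64,    # channels coming from the previous (deeper) block
    temb_channels=128,
    num_layers=2,
    attn_num_head_channels=8,
    cross_attention_dim=40,
    add_upsample=True,
)

hidden_states = torch.randn(1, 64, 16, 16)
# one skip tensor per resnet, consumed from the end of the tuple inside forward()
res_hidden_states_tuple = (torch.randn(1, 32, 16, 16), torch.randn(1, 64, 16, 16))
temb = torch.randn(1, 128)
context = torch.randn(1, 77, 40)

out = block(hidden_states, res_hidden_states_tuple, temb=temb, encoder_hidden_states=context)
print(out.shape)  # expected torch.Size([1, 64, 32, 32]); Upsample2D doubles the spatial size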
@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
             self.act = None

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             ...
@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
             self.act = None

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             ...
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py  (view file @ d5acb411)
 from typing import Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 ...
@@ -15,6 +14,69 @@ from transformers.utils import logging
 from ...pipeline_utils import DiffusionPipeline


+class LatentDiffusionPipeline(DiffusionPipeline):
+    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
+        # eta corresponds to η in paper and should be between [0, 1]
+
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+        self.vqvae.to(torch_device)
+        self.bert.to(torch_device)
+
+        # get unconditional embeddings for classifier free guidence
+        if guidance_scale != 1.0:
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+            uncond_embeddings = self.bert(uncond_input.input_ids)
+
+        # get text embedding
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+        text_embedding = self.bert(text_input.input_ids)
+
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            generator=generator,
+        ).to(torch_device)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm.tqdm(self.scheduler.timesteps):
+            # 1. predict noise residual
+            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+
+            if isinstance(pred_noise_t, dict):
+                pred_noise_t = pred_noise_t["sample"]
+
+            # 2. predict previous mean of image x_t-1 and add variance depending on eta
+            # do x_t -> x_t-1
+            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+
+        # scale and decode image with vae
+        image = 1 / 0.18215 * image
+        image = self.vqvae.decode(image)
+
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+
+        return image
+
+
 ################################################################################
 # Code for the text transformer model
 ################################################################################
 ...
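A minimal usage sketch for the pipeline defined above (not part of this diff). The model id and prompt are the ones used in tests/test_modeling_utils.py; the import path is an assumption:

import torch
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LatentDiffusionPipeline  # assumed path

ldm = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.manual_seed(0)
image = ldm("A painting of a squirrel eating a burger", generator=generator, num_inference_steps=50)
# `image` is the decoded VQ-VAE output clamped to [0, 1]; its exact shape depends on the checkpoint's config.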
@@ -542,100 +604,3 @@ class LDMBertModel(LDMBertPreTrainedModel):
         )
         sequence_output = outputs[0]

         return sequence_output
\ No newline at end of file
-
-
-class LatentDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
-        super().__init__()
-        scheduler = scheduler.set_format("pt")
-        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-
-    @torch.no_grad()
-    def __call__(
-        self,
-        prompt,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        guidance_scale=1.0,
-        num_inference_steps=50,
-    ):
-        # eta corresponds to η in paper and should be between [0, 1]
-
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
-        self.bert.to(torch_device)
-
-        # get unconditional embeddings for classifier free guidence
-        if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-            uncond_embeddings = self.bert(uncond_input.input_ids)
-
-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
-
-        num_trained_timesteps = self.scheduler.config.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            generator=generator,
-        ).to(torch_device)
-
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
-        for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # guidance_scale of 1 means no guidance
-            if guidance_scale == 1.0:
-                image_in = image
-                context = text_embedding
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-            else:
-                # for classifier free guidance, we need to do two forward passes
-                # here we concanate embedding and unconditioned embedding in a single batch
-                # to avoid doing two forward passes
-                image_in = torch.cat([image] * 2)
-                context = torch.cat([uncond_embeddings, text_embedding])
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image_in, timesteps, context=context)
-
-            # perform guidance
-            if guidance_scale != 1.0:
-                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
-                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
-
-            # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
-
-            # 3. optionally sample variance
-            variance = 0
-            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-
-            # 4. set current image to prev_image: x_t -> x_t-1
-            image = pred_prev_image + variance
-
-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-
-        return image
tests/test_modeling_utils.py  (view file @ d5acb411)
@@ -40,14 +40,17 @@ from diffusers import (
     ScoreSdeVeScheduler,
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
+    UNetConditionalModel,
     UNetLDMModel,
     UNetUnconditionalModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
+from transformers import BertTokenizer

 torch.backends.cuda.matmul.allow_tf32 = False
 ...
@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))


-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL

     @property
     ...
@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

     @slow
     @unittest.skip("Skipping for now as it takes too long")
     def test_ldm_text2img(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
         ...
@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ldm_text2img_fast(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
         ...
@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+        model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
         torch.manual_seed(0)
         if torch.cuda.is_available():
         ...