renzhc / diffusers_dcu · Commits · d5acb411

Unverified commit d5acb411, authored Jul 19, 2022 by Patrick von Platen, committed via GitHub on Jul 19, 2022.
Parent: 6cabc599

Finalize ldm (#96)

* upload
* make checkpoint work
* finalize
Showing 7 changed files with 999 additions and 135 deletions (+999 −135):

* src/diffusers/__init__.py (+9 −1)
* src/diffusers/models/__init__.py (+1 −0)
* src/diffusers/models/attention.py (+22 −24)
* src/diffusers/models/unet_conditional.py (+632 −0)
* src/diffusers/models/unet_new.py (+264 −5)
* src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py (+64 −99)
* tests/test_modeling_utils.py (+7 −6)
src/diffusers/__init__.py  (+9 −1)

```diff
@@ -7,7 +7,15 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"
 
 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import (
+    AutoencoderKL,
+    NCSNpp,
+    UNetConditionalModel,
+    UNetLDMModel,
+    UNetModel,
+    UNetUnconditionalModel,
+    VQModel,
+)
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import (
     DDIMPipeline,
```
src/diffusers/models/__init__.py  (+1 −0)

```diff
@@ -17,6 +17,7 @@
 # limitations under the License.
 
 from .unet import UNetModel
+from .unet_conditional import UNetConditionalModel
 from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
 from .unet_ldm import UNetLDMModel
 from .unet_sde_score_estimation import NCSNpp
```
src/diffusers/models/attention.py  (+22 −24)

```diff
@@ -42,7 +42,7 @@ class AttentionBlockNew(nn.Module):
         self.value = nn.Linear(channels, channels)
 
         self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
+        self.proj_attn = nn.Linear(channels, channels, 1)
 
     def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
         new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
@@ -147,6 +147,8 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
         super().__init__()
+        self.n_heads = n_heads
+        self.d_head = d_head
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
         self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@@ -160,7 +162,7 @@ class SpatialTransformer(nn.Module):
             ]
         )
 
-        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
+        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
 
     def forward(self, x, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
@@ -175,6 +177,12 @@ class SpatialTransformer(nn.Module):
         x = self.proj_out(x)
         return x + x_in
 
+    def set_weight(self, layer):
+        self.norm = layer.norm
+        self.proj_in = layer.proj_in
+        self.transformer_blocks = layer.transformer_blocks
+        self.proj_out = layer.proj_out
+
 
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -270,14 +278,15 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 
+# TODO(Patrick) - this can and should be removed
 def zero_module(module):
     """
     Zero out the parameters of a module and return it.
     """
     for p in module.parameters():
         p.detach().zero_()
     return module
 
 
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
         return x * F.gelu(gate)
 
 
 # TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
@@ -298,17 +307,6 @@ def default(val, d):
     return d() if isfunction(d) else d
 
 
-# feedforward
-class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
-
-
 # the main attention block that is used for all models
 class AttentionBlock(nn.Module):
     """
@@ -348,7 +346,7 @@ class AttentionBlock(nn.Module):
         if encoder_channels is not None:
             self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-        self.proj = zero_module(nn.Conv1d(channels, channels, 1))
+        self.proj = nn.Conv1d(channels, channels, 1)
 
         self.overwrite_qkv = overwrite_qkv
         self.overwrite_linear = overwrite_linear
@@ -370,7 +368,7 @@ class AttentionBlock(nn.Module):
             self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
         else:
-            self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+            self.proj_out = nn.Conv1d(channels, channels, 1)
 
         self.set_weights(self)
         self.is_overwritten = False
@@ -385,7 +383,7 @@ class AttentionBlock(nn.Module):
             self.qkv.weight.data = qkv_weight
             self.qkv.bias.data = qkv_bias
 
-            proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
+            proj_out = nn.Conv1d(self.channels, self.channels, 1)
             proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
             proj_out.bias.data = module.proj_out.bias.data
```
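Annotator note: the recurring edit in this file drops the zero_module(...) wrapper around output projections, presumably because these layers are now populated from converted checkpoints (see the set_weight helper added above) rather than trained from scratch. The sketch below is plain PyTorch illustrating what zero_module does; it is not code from this commit.

```python
import torch
import torch.nn as nn


def zero_module(module: nn.Module) -> nn.Module:
    # Same helper as in attention.py above: zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


proj = zero_module(nn.Linear(8, 8))
x = torch.randn(2, 8)
x_in = torch.randn(2, 8)

# With zeroed weights the projection contributes nothing, so a block that ends in
# `proj(x) + x_in` starts out as an identity mapping. That is useful for training
# from scratch, but irrelevant once the weights are overwritten by a converted checkpoint.
print(proj(x).abs().max())                 # tensor(0.)
print(torch.equal(proj(x) + x_in, x_in))   # True
```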
src/diffusers/models/unet_conditional.py  (new file, 0 → 100644, +632 −0)

(Diff collapsed in this view; contents not shown.)
src/diffusers/models/unet_new.py  (+264 −5)

```diff
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 from torch import nn
 
-from .attention import AttentionBlockNew
+from .attention import AttentionBlockNew, SpatialTransformer
 from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
@@ -56,6 +56,18 @@ def get_down_block(
             downsample_padding=downsample_padding,
             attn_num_head_channels=attn_num_head_channels,
         )
+    elif down_block_type == "UNetResCrossAttnDownBlock2D":
+        return UNetResCrossAttnDownBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            downsample_padding=downsample_padding,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif down_block_type == "UNetResSkipDownBlock2D":
         return UNetResSkipDownBlock2D(
             num_layers=num_layers,
@@ -104,6 +116,18 @@ def get_up_block(
             resnet_eps=resnet_eps,
             resnet_act_fn=resnet_act_fn,
         )
+    elif up_block_type == "UNetResCrossAttnUpBlock2D":
+        return UNetResCrossAttnUpBlock2D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            attn_num_head_channels=attn_num_head_channels,
+        )
     elif up_block_type == "UNetResAttnUpBlock2D":
         return UNetResAttnUpBlock2D(
             num_layers=num_layers,
@@ -221,6 +245,83 @@ class UNetMidBlock2D(nn.Module):
         return hidden_states
 
 
+class UNetMidBlock2DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        attention_type="default",
+        output_scale_factor=1.0,
+        cross_attention_dim=1280,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.attention_type = attention_type
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_layers):
+            attentions.append(
+                SpatialTransformer(
+                    in_channels,
+                    attn_num_head_channels,
+                    in_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(hidden_states, encoder_hidden_states)
+            hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
 class UNetResAttnDownBlock2D(nn.Module):
     def __init__(
         self,
@@ -302,6 +403,88 @@ class UNetResAttnDownBlock2D(nn.Module):
         return hidden_states, output_states
 
 
+class UNetResCrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_downsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels,
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+        output_states = ()
+
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
+
 class UNetResDownBlock2D(nn.Module):
     def __init__(
         self,
@@ -618,6 +801,86 @@ class UNetResAttnUpBlock2D(nn.Module):
         return hidden_states
 
 
+class UNetResCrossAttnUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        cross_attention_dim=1280,
+        attention_type="default",
+        output_scale_factor=1.0,
+        downsample_padding=1,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.attention_type = attention_type
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                SpatialTransformer(
+                    out_channels,
+                    attn_num_head_channels,
+                    out_channels // attn_num_head_channels,
+                    depth=1,
+                    context_dim=cross_attention_dim,
+                )
+            )
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class UNetResUpBlock2D(nn.Module):
     def __init__(
         self,
@@ -765,8 +1028,6 @@ class UNetResAttnSkipUpBlock2D(nn.Module):
         self.act = None
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -864,8 +1125,6 @@ class UNetResSkipUpBlock2D(nn.Module):
         self.act = None
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
-        output_states = ()
-
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
```
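Annotator note: the three new blocks above (UNetMidBlock2DCrossAttn, UNetResCrossAttnDownBlock2D, UNetResCrossAttnUpBlock2D) all follow the same pattern: alternate a ResnetBlock with a SpatialTransformer and thread the text-encoder output through as context / encoder_hidden_states. Below is a minimal stand-alone sketch of that control flow with toy stand-in modules, not the classes from this commit.

```python
# Toy sketch of the forward pass of a cross-attention down block:
# alternate resnet and cross-attention, collecting intermediate states for the skip connections.
# ToyResnet / ToyCrossAttn are stand-ins, not code from this commit.
import torch
import torch.nn as nn


class ToyResnet(nn.Module):
    def forward(self, hidden_states, temb=None):
        return hidden_states  # the real ResnetBlock also consumes the time embedding


class ToyCrossAttn(nn.Module):
    def forward(self, hidden_states, context=None):
        return hidden_states  # the real SpatialTransformer attends over `context`


resnets = nn.ModuleList([ToyResnet(), ToyResnet()])
attentions = nn.ModuleList([ToyCrossAttn(), ToyCrossAttn()])

hidden_states = torch.randn(1, 4, 8, 8)           # (batch, channels, height, width)
encoder_hidden_states = torch.randn(1, 77, 1280)  # e.g. text embeddings, context_dim = 1280
output_states = ()

for resnet, attn in zip(resnets, attentions):
    hidden_states = resnet(hidden_states, temb=None)
    hidden_states = attn(hidden_states, context=encoder_hidden_states)
    output_states += (hidden_states,)  # later consumed by the up blocks as skip connections
```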
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py  (+64 −99)

```diff
 from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
@@ -15,6 +14,69 @@ from transformers.utils import logging
 from ...pipeline_utils import DiffusionPipeline
 
 
+class LatentDiffusionPipeline(DiffusionPipeline):
+    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
+        # eta corresponds to η in paper and should be between [0, 1]
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.unet.to(torch_device)
+        self.vqvae.to(torch_device)
+        self.bert.to(torch_device)
+
+        # get unconditional embeddings for classifier free guidence
+        if guidance_scale != 1.0:
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+            uncond_embeddings = self.bert(uncond_input.input_ids)
+
+        # get text embedding
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+        text_embedding = self.bert(text_input.input_ids)
+
+        image = torch.randn(
+            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
+            generator=generator,
+        ).to(torch_device)
+
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm.tqdm(self.scheduler.timesteps):
+            # 1. predict noise residual
+            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+
+            if isinstance(pred_noise_t, dict):
+                pred_noise_t = pred_noise_t["sample"]
+
+            # 2. predict previous mean of image x_t-1 and add variance depending on eta
+            # do x_t -> x_t-1
+            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+
+        # scale and decode image with vae
+        image = 1 / 0.18215 * image
+        image = self.vqvae.decode(image)
+
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+
+        return image
+
+
 ################################################################################
 # Code for the text transformer model
 ################################################################################
@@ -541,101 +603,4 @@ class LDMBertModel(LDMBertPreTrainedModel):
             return_dict=return_dict,
         )
         sequence_output = outputs[0]
-        return sequence_output
-
-
-class LatentDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
-        super().__init__()
-        scheduler = scheduler.set_format("pt")
-        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-
-    @torch.no_grad()
-    def __call__(
-        self,
-        prompt,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        guidance_scale=1.0,
-        num_inference_steps=50,
-    ):
-        # eta corresponds to η in paper and should be between [0, 1]
-        if torch_device is None:
-            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        self.unet.to(torch_device)
-        self.vqvae.to(torch_device)
-        self.bert.to(torch_device)
-
-        # get unconditional embeddings for classifier free guidence
-        if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-            uncond_embeddings = self.bert(uncond_input.input_ids)
-
-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
-
-        num_trained_timesteps = self.scheduler.config.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            generator=generator,
-        ).to(torch_device)
-
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
-        for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # guidance_scale of 1 means no guidance
-            if guidance_scale == 1.0:
-                image_in = image
-                context = text_embedding
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-            else:
-                # for classifier free guidance, we need to do two forward passes
-                # here we concanate embedding and unconditioned embedding in a single batch
-                # to avoid doing two forward passes
-                image_in = torch.cat([image] * 2)
-                context = torch.cat([uncond_embeddings, text_embedding])
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image_in, timesteps, context=context)
-
-            # perform guidance
-            if guidance_scale != 1.0:
-                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
-                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
-
-            # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
-
-            # 3. optionally sample variance
-            variance = 0
-            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-
-            # 4. set current image to prev_image: x_t -> x_t-1
-            image = pred_prev_image + variance
-
-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-
-        return image
+        return sequence_output
\ No newline at end of file
```
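Annotator note: for reference, a hedged usage sketch of the pipeline defined above. It assumes LatentDiffusionPipeline is importable from the top-level diffusers package at this revision (as the updated tests suggest) and reuses the fusing/latent-diffusion-text2im-large checkpoint id that appears in tests/test_modeling_utils.py. The __call__ arguments mirror this revision (v0.0.4) and differ from later diffusers releases.

```python
# Usage sketch against this revision only; import path and checkpoint id are taken
# from the tests in this commit, not guaranteed for other versions.
import torch

from diffusers import LatentDiffusionPipeline

ldm = LatentDiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")

generator = torch.manual_seed(0)
# In this version __call__ returns the decoded images as a tensor clamped to [0, 1].
images = ldm(
    "A painting of a squirrel eating a burger",
    generator=generator,
    num_inference_steps=50,
    guidance_scale=1.0,  # defaults shown explicitly; see the __call__ signature above
    eta=0.0,
)
```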
tests/test_modeling_utils.py  (+7 −6)

```diff
@@ -40,14 +40,17 @@ from diffusers import (
     ScoreSdeVeScheduler,
     ScoreSdeVpPipeline,
     ScoreSdeVpScheduler,
+    UNetConditionalModel,
     UNetLDMModel,
     UNetUnconditionalModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
+from transformers import BertTokenizer
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -827,7 +830,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
 
-class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
 
     @property
@@ -1026,10 +1029,8 @@ class PipelineTesterMixin(unittest.TestCase):
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
 
     @slow
-    @unittest.skip("Skipping for now as it takes too long")
     def test_ldm_text2img(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -1043,8 +1044,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ldm_text2img_fast(self):
-        model_id = "fusing/latent-diffusion-text2im-large"
-        ldm = LatentDiffusionPipeline.from_pretrained(model_id)
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/latent-diffusion-text2im-large")
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -1074,6 +1074,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_score_sde_ve_pipeline(self):
         model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
+        model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
```