Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
eceeb972
Unverified
Commit
eceeb972
authored
Jun 29, 2022
by
Suraj Patil
Committed by
GitHub
Jun 29, 2022
Browse files
move the VAE models in src/models
move the VAE models in src/models
parents
814133ec
333a8da6
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
790 additions
and
1412 deletions
+790
-1412
src/diffusers/__init__.py
src/diffusers/__init__.py
+1
-1
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+1
-0
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+636
-0
src/diffusers/pipelines/latent_diffusion/__init__.py
src/diffusers/pipelines/latent_diffusion/__init__.py
+1
-1
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
...s/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+0
-849
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+0
-561
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+151
-0
No files found.
src/diffusers/__init__.py
View file @
eceeb972
...
@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
...
@@ -7,7 +7,7 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__
=
"0.0.4"
__version__
=
"0.0.4"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.models
import
AutoencoderKL
,
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
,
VQModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
(
from
.pipelines
import
(
BDDMPipeline
,
BDDMPipeline
,
...
...
src/diffusers/models/__init__.py
View file @
eceeb972
...
@@ -22,3 +22,4 @@ from .unet_grad_tts import UNetGradTTSModel
...
@@ -22,3 +22,4 @@ from .unet_grad_tts import UNetGradTTSModel
from
.unet_ldm
import
UNetLDMModel
from
.unet_ldm
import
UNetLDMModel
from
.unet_rl
import
TemporalUNet
from
.unet_rl
import
TemporalUNet
from
.unet_sde_score_estimation
import
NCSNpp
from
.unet_sde_score_estimation
import
NCSNpp
from
.vae
import
AutoencoderKL
,
VQModel
src/diffusers/models/vae.py
0 → 100644
View file @
eceeb972
import
math
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
.attention
import
AttentionBlock
from
.resnet
import
Downsample
,
Upsample
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
use_conv
=
resamp_with_conv
,
padding
=
0
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttentionBlock
(
block_in
,
overwrite_qkv
=
True
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
use_conv
=
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
):
# reshape z -> (batch, height, width, channel) and flatten
z
=
z
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
self
.
embedding
.
weight
.
t
())
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
DiagonalGaussianDistribution
(
object
):
def
__init__
(
self
,
parameters
,
deterministic
=
False
):
self
.
parameters
=
parameters
self
.
mean
,
self
.
logvar
=
torch
.
chunk
(
parameters
,
2
,
dim
=
1
)
self
.
logvar
=
torch
.
clamp
(
self
.
logvar
,
-
30.0
,
20.0
)
self
.
deterministic
=
deterministic
self
.
std
=
torch
.
exp
(
0.5
*
self
.
logvar
)
self
.
var
=
torch
.
exp
(
self
.
logvar
)
if
self
.
deterministic
:
self
.
var
=
self
.
std
=
torch
.
zeros_like
(
self
.
mean
).
to
(
device
=
self
.
parameters
.
device
)
def
sample
(
self
):
x
=
self
.
mean
+
self
.
std
*
torch
.
randn
(
self
.
mean
.
shape
).
to
(
device
=
self
.
parameters
.
device
)
return
x
def
kl
(
self
,
other
=
None
):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
else
:
if
other
is
None
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
else
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
],
)
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
def
mode
(
self
):
return
self
.
mean
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
z_channels
,
embed_dim
,
1
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
def
forward
(
self
,
x
):
h
=
self
.
encode
(
x
)
dec
=
self
.
decode
(
h
)
return
dec
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
moments
=
self
.
quant_conv
(
h
)
posterior
=
DiagonalGaussianDistribution
(
moments
)
return
posterior
def
decode
(
self
,
z
):
z
=
self
.
post_quant_conv
(
z
)
dec
=
self
.
decoder
(
z
)
return
dec
def
forward
(
self
,
x
,
sample_posterior
=
False
):
posterior
=
self
.
encode
(
x
)
if
sample_posterior
:
z
=
posterior
.
sample
()
else
:
z
=
posterior
.
mode
()
dec
=
self
.
decode
(
z
)
return
dec
src/diffusers/pipelines/latent_diffusion/__init__.py
View file @
eceeb972
...
@@ -2,4 +2,4 @@ from ...utils import is_transformers_available
...
@@ -2,4 +2,4 @@ from ...utils import is_transformers_available
if
is_transformers_available
():
if
is_transformers_available
():
from
.pipeline_latent_diffusion
import
AutoencoderKL
,
LatentDiffusionPipeline
,
LDMBertModel
from
.pipeline_latent_diffusion
import
LatentDiffusionPipeline
,
LDMBertModel
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
View file @
eceeb972
import
math
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -13,8 +12,6 @@ from transformers.modeling_outputs import BaseModelOutput
...
@@ -13,8 +12,6 @@ from transformers.modeling_outputs import BaseModelOutput
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
logging
from
transformers.utils
import
logging
from
...configuration_utils
import
ConfigMixin
from
...modeling_utils
import
ModelMixin
from
...pipeline_utils
import
DiffusionPipeline
from
...pipeline_utils
import
DiffusionPipeline
...
@@ -547,852 +544,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
...
@@ -547,852 +544,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
return
sequence_output
return
sequence_output
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
use_timestep
=
True
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
self
.
ch
*
4
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
use_timestep
=
use_timestep
if
self
.
use_timestep
:
# timestep embedding
self
.
temb
=
nn
.
Module
()
self
.
temb
.
dense
=
nn
.
ModuleList
(
[
torch
.
nn
.
Linear
(
self
.
ch
,
self
.
temb_ch
),
torch
.
nn
.
Linear
(
self
.
temb_ch
,
self
.
temb_ch
),
]
)
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
skip_in
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
if
i_block
==
self
.
num_res_blocks
:
skip_in
=
ch
*
in_ch_mult
[
i_level
]
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
+
skip_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
,
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
,
t
=
None
):
# assert x.shape[2] == x.shape[3] == self.resolution
if
self
.
use_timestep
:
# timestep embedding
assert
t
is
not
None
temb
=
get_timestep_embedding
(
t
,
self
.
ch
)
temb
=
self
.
temb
.
dense
[
0
](
temb
)
temb
=
nonlinearity
(
temb
)
temb
=
self
.
temb
.
dense
[
1
](
temb
)
else
:
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
),
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
,
temp
=
None
,
rescale_logits
=
False
,
return_logits
=
False
):
assert
temp
is
None
or
temp
==
1.0
,
"Only for interface compatible with Gumbel"
assert
rescale_logits
==
False
,
"Only for interface compatible with Gumbel"
assert
return_logits
==
False
,
"Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z
=
rearrange
(
z
,
"b c h w -> b h w c"
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
rearrange
(
self
.
embedding
.
weight
,
"n d -> d n"
))
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
rearrange
(
z_q
,
"b h w c -> b c h w"
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
class
DiagonalGaussianDistribution
(
object
):
def
__init__
(
self
,
parameters
,
deterministic
=
False
):
self
.
parameters
=
parameters
self
.
mean
,
self
.
logvar
=
torch
.
chunk
(
parameters
,
2
,
dim
=
1
)
self
.
logvar
=
torch
.
clamp
(
self
.
logvar
,
-
30.0
,
20.0
)
self
.
deterministic
=
deterministic
self
.
std
=
torch
.
exp
(
0.5
*
self
.
logvar
)
self
.
var
=
torch
.
exp
(
self
.
logvar
)
if
self
.
deterministic
:
self
.
var
=
self
.
std
=
torch
.
zeros_like
(
self
.
mean
).
to
(
device
=
self
.
parameters
.
device
)
def
sample
(
self
):
x
=
self
.
mean
+
self
.
std
*
torch
.
randn
(
self
.
mean
.
shape
).
to
(
device
=
self
.
parameters
.
device
)
return
x
def
kl
(
self
,
other
=
None
):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
else
:
if
other
is
None
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
else
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
],
)
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.0
])
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
def
mode
(
self
):
return
self
.
mean
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
moments
=
self
.
quant_conv
(
h
)
posterior
=
DiagonalGaussianDistribution
(
moments
)
return
posterior
def
decode
(
self
,
z
):
z
=
self
.
post_quant_conv
(
z
)
dec
=
self
.
decoder
(
z
)
return
dec
def
forward
(
self
,
input
,
sample_posterior
=
True
):
posterior
=
self
.
encode
(
input
)
if
sample_posterior
:
z
=
posterior
.
sample
()
else
:
z
=
posterior
.
mode
()
dec
=
self
.
decode
(
z
)
return
dec
,
posterior
class
LatentDiffusionPipeline
(
DiffusionPipeline
):
class
LatentDiffusionPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
noise_scheduler
):
super
().
__init__
()
super
().
__init__
()
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
eceeb972
import
math
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
tqdm
import
tqdm
from
...configuration_utils
import
ConfigMixin
from
...modeling_utils
import
ModelMixin
from
...pipeline_utils
import
DiffusionPipeline
from
...pipeline_utils
import
DiffusionPipeline
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
):
# reshape z -> (batch, height, width, channel) and flatten
z
=
z
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
self
.
embedding
.
weight
.
t
())
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
class
LatentDiffusionUncondPipeline
(
DiffusionPipeline
):
class
LatentDiffusionUncondPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
vqvae
,
unet
,
noise_scheduler
):
super
().
__init__
()
super
().
__init__
()
...
...
tests/test_modeling_utils.py
View file @
eceeb972
...
@@ -22,6 +22,7 @@ import numpy as np
...
@@ -22,6 +22,7 @@ import numpy as np
import
torch
import
torch
from
diffusers
import
(
from
diffusers
import
(
AutoencoderKL
,
BDDMPipeline
,
BDDMPipeline
,
DDIMPipeline
,
DDIMPipeline
,
DDIMScheduler
,
DDIMScheduler
,
...
@@ -44,6 +45,8 @@ from diffusers import (
...
@@ -44,6 +45,8 @@ from diffusers import (
UNetGradTTSModel
,
UNetGradTTSModel
,
UNetLDMModel
,
UNetLDMModel
,
UNetModel
,
UNetModel
,
VQModel
,
AutoencoderKL
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
@@ -805,6 +808,154 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -805,6 +808,154 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
VQModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
VQModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
image
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
return
{
"x"
:
image
}
@
property
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
64
,
"out_ch"
:
3
,
"num_res_blocks"
:
1
,
"attn_resolutions"
:
[],
"in_channels"
:
3
,
"resolution"
:
32
,
"z_channels"
:
3
,
"n_embed"
:
256
,
"embed_dim"
:
3
,
"sane_index_shape"
:
False
,
"ch_mult"
:
(
1
,),
"dropout"
:
0.0
,
"double_z"
:
False
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_forward_signature
(
self
):
pass
def
test_training
(
self
):
pass
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
VQModel
.
from_pretrained
(
"fusing/vqgan-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
with
torch
.
no_grad
():
output
=
model
(
image
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
1.1321
,
0.1056
,
0.3505
,
-
0.6461
,
-
0.2014
,
0.0419
,
-
0.5763
,
-
0.8462
,
-
0.4218
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
AutoEncoderKLTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
AutoencoderKL
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
image
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
return
{
"x"
:
image
}
@
property
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
64
,
"ch_mult"
:
(
1
,),
"embed_dim"
:
4
,
"in_channels"
:
3
,
"num_res_blocks"
:
1
,
"out_ch"
:
3
,
"resolution"
:
32
,
"z_channels"
:
4
,
"attn_resolutions"
:
[]
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_forward_signature
(
self
):
pass
def
test_training
(
self
):
pass
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
with
torch
.
no_grad
():
output
=
model
(
image
,
sample_posterior
=
True
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0814
,
-
0.0229
,
-
0.1320
,
-
0.4123
,
-
0.0366
,
-
0.3473
,
0.0438
,
-
0.1662
,
0.1750
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
PipelineTesterMixin
(
unittest
.
TestCase
):
class
PipelineTesterMixin
(
unittest
.
TestCase
):
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
# 1. Load models
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment