Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
d74b804d
"torchvision/csrc/ops/roi_align.cpp" did not exist on "7f7e7663e1bd5b8ae4e6f4a0b561319777769bdf"
Commit
d74b804d
authored
Jun 28, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
a859b199
22b63d15
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
642 additions
and
1 deletion
+642
-1
src/diffusers/__init__.py
src/diffusers/__init__.py
+9
-1
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-0
src/diffusers/pipelines/pipeline_latent_diffusion_uncond.py
src/diffusers/pipelines/pipeline_latent_diffusion_uncond.py
+632
-0
No files found.
src/diffusers/__init__.py
View file @
d74b804d
...
...
@@ -9,7 +9,15 @@ __version__ = "0.0.4"
from
.modeling_utils
import
ModelMixin
from
.models
import
NCSNpp
,
TemporalUNet
,
UNetLDMModel
,
UNetModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
,
ScoreSdeVpPipeline
from
.pipelines
import
(
BDDMPipeline
,
DDIMPipeline
,
DDPMPipeline
,
LatentDiffusionUncondPipeline
,
PNDMPipeline
,
ScoreSdeVePipeline
,
ScoreSdeVpPipeline
,
)
from
.schedulers
import
(
DDIMScheduler
,
DDPMScheduler
,
...
...
src/diffusers/pipelines/__init__.py
View file @
d74b804d
...
...
@@ -2,6 +2,7 @@ from ..utils import is_inflect_available, is_transformers_available, is_unidecod
from
.pipeline_bddm
import
BDDMPipeline
from
.pipeline_ddim
import
DDIMPipeline
from
.pipeline_ddpm
import
DDPMPipeline
from
.pipeline_latent_diffusion_uncond
import
LatentDiffusionUncondPipeline
from
.pipeline_pndm
import
PNDMPipeline
from
.pipeline_score_sde_ve
import
ScoreSdeVePipeline
from
.pipeline_score_sde_vp
import
ScoreSdeVpPipeline
...
...
src/diffusers/pipelines/pipeline_latent_diffusion_uncond.py
0 → 100644
View file @
d74b804d
import
math
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
tqdm
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..pipeline_utils
import
DiffusionPipeline
def
get_timestep_embedding
(
timesteps
,
embedding_dim
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal
embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section
3.5 of "Attention Is All You Need".
"""
assert
len
(
timesteps
.
shape
)
==
1
half_dim
=
embedding_dim
//
2
emb
=
math
.
log
(
10000
)
/
(
half_dim
-
1
)
emb
=
torch
.
exp
(
torch
.
arange
(
half_dim
,
dtype
=
torch
.
float32
)
*
-
emb
)
emb
=
emb
.
to
(
device
=
timesteps
.
device
)
emb
=
timesteps
.
float
()[:,
None
]
*
emb
[
None
,
:]
emb
=
torch
.
cat
([
torch
.
sin
(
emb
),
torch
.
cos
(
emb
)],
dim
=
1
)
if
embedding_dim
%
2
==
1
:
# zero pad
emb
=
torch
.
nn
.
functional
.
pad
(
emb
,
(
0
,
1
,
0
,
0
))
return
emb
def
nonlinearity
(
x
):
# swish
return
x
*
torch
.
sigmoid
(
x
)
def
Normalize
(
in_channels
):
return
torch
.
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
with_conv
:
x
=
self
.
conv
(
x
)
return
x
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
with_conv
):
super
().
__init__
()
self
.
with_conv
=
with_conv
if
self
.
with_conv
:
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
):
if
self
.
with_conv
:
pad
=
(
0
,
1
,
0
,
1
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
else
:
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
)
return
x
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
*
,
in_channels
,
out_channels
=
None
,
conv_shortcut
=
False
,
dropout
,
temb_channels
=
512
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
use_conv_shortcut
=
conv_shortcut
self
.
norm1
=
Normalize
(
in_channels
)
self
.
conv1
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
temb_channels
>
0
:
self
.
temb_proj
=
torch
.
nn
.
Linear
(
temb_channels
,
out_channels
)
self
.
norm2
=
Normalize
(
out_channels
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
dropout
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
self
.
conv_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
else
:
self
.
nin_shortcut
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
temb
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv1
(
h
)
if
temb
is
not
None
:
h
=
h
+
self
.
temb_proj
(
nonlinearity
(
temb
))[:,
:,
None
,
None
]
h
=
self
.
norm2
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
dropout
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
if
self
.
use_conv_shortcut
:
x
=
self
.
conv_shortcut
(
x
)
else
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
Normalize
(
in_channels
)
self
.
q
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
k
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
v
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
proj_out
=
torch
.
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h_
=
x
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
# compute attention
b
,
c
,
h
,
w
=
q
.
shape
q
=
q
.
reshape
(
b
,
c
,
h
*
w
)
q
=
q
.
permute
(
0
,
2
,
1
)
# b,hw,c
k
=
k
.
reshape
(
b
,
c
,
h
*
w
)
# b,c,hw
w_
=
torch
.
bmm
(
q
,
k
)
# b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_
=
w_
*
(
int
(
c
)
**
(
-
0.5
))
w_
=
torch
.
nn
.
functional
.
softmax
(
w_
,
dim
=
2
)
# attend to values
v
=
v
.
reshape
(
b
,
c
,
h
*
w
)
w_
=
w_
.
permute
(
0
,
2
,
1
)
# b,hw,hw (first hw of k, second of q)
h_
=
torch
.
bmm
(
v
,
w_
)
# b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_
=
h_
.
reshape
(
b
,
c
,
h
,
w
)
h_
=
self
.
proj_out
(
h_
)
return
x
+
h_
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
double_z
=
True
,
**
ignore_kwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
down
=
nn
.
ModuleList
()
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
2
*
z_channels
if
double_z
else
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
# timestep embedding
temb
=
None
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
],
temb
)
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# end
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
*
,
ch
,
out_ch
,
ch_mult
=
(
1
,
2
,
4
,
8
),
num_res_blocks
,
attn_resolutions
,
dropout
=
0.0
,
resamp_with_conv
=
True
,
in_channels
,
resolution
,
z_channels
,
give_pre_end
=
False
,
**
ignorekwargs
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
temb_ch
=
0
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
give_pre_end
=
give_pre_end
# compute in_ch_mult, block_in and curr_res at lowest res
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
,
temb_channels
=
self
.
temb_ch
,
dropout
=
dropout
)
)
block_in
=
block_out
if
curr_res
in
attn_resolutions
:
attn
.
append
(
AttnBlock
(
block_in
))
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
,
resamp_with_conv
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
Normalize
(
block_in
)
self
.
conv_out
=
torch
.
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
):
# assert z.shape[1:] == self.z_shape[1:]
self
.
last_z_shape
=
z
.
shape
# timestep embedding
temb
=
None
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
,
temb
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
,
temb
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
,
temb
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
if
self
.
give_pre_end
:
return
h
h
=
self
.
norm_out
(
h
)
h
=
nonlinearity
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
VectorQuantizer
(
nn
.
Module
):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def
__init__
(
self
,
n_e
,
e_dim
,
beta
,
remap
=
None
,
unknown_index
=
"random"
,
sane_index_shape
=
False
,
legacy
=
True
):
super
().
__init__
()
self
.
n_e
=
n_e
self
.
e_dim
=
e_dim
self
.
beta
=
beta
self
.
legacy
=
legacy
self
.
embedding
=
nn
.
Embedding
(
self
.
n_e
,
self
.
e_dim
)
self
.
embedding
.
weight
.
data
.
uniform_
(
-
1.0
/
self
.
n_e
,
1.0
/
self
.
n_e
)
self
.
remap
=
remap
if
self
.
remap
is
not
None
:
self
.
register_buffer
(
"used"
,
torch
.
tensor
(
np
.
load
(
self
.
remap
)))
self
.
re_embed
=
self
.
used
.
shape
[
0
]
self
.
unknown_index
=
unknown_index
# "random" or "extra" or integer
if
self
.
unknown_index
==
"extra"
:
self
.
unknown_index
=
self
.
re_embed
self
.
re_embed
=
self
.
re_embed
+
1
print
(
f
"Remapping
{
self
.
n_e
}
indices to
{
self
.
re_embed
}
indices. "
f
"Using
{
self
.
unknown_index
}
for unknown indices."
)
else
:
self
.
re_embed
=
n_e
self
.
sane_index_shape
=
sane_index_shape
def
remap_to_used
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
match
=
(
inds
[:,
:,
None
]
==
used
[
None
,
None
,
...]).
long
()
new
=
match
.
argmax
(
-
1
)
unknown
=
match
.
sum
(
2
)
<
1
if
self
.
unknown_index
==
"random"
:
new
[
unknown
]
=
torch
.
randint
(
0
,
self
.
re_embed
,
size
=
new
[
unknown
].
shape
).
to
(
device
=
new
.
device
)
else
:
new
[
unknown
]
=
self
.
unknown_index
return
new
.
reshape
(
ishape
)
def
unmap_to_all
(
self
,
inds
):
ishape
=
inds
.
shape
assert
len
(
ishape
)
>
1
inds
=
inds
.
reshape
(
ishape
[
0
],
-
1
)
used
=
self
.
used
.
to
(
inds
)
if
self
.
re_embed
>
self
.
used
.
shape
[
0
]:
# extra token
inds
[
inds
>=
self
.
used
.
shape
[
0
]]
=
0
# simply set to zero
back
=
torch
.
gather
(
used
[
None
,
:][
inds
.
shape
[
0
]
*
[
0
],
:],
1
,
inds
)
return
back
.
reshape
(
ishape
)
def
forward
(
self
,
z
):
# reshape z -> (batch, height, width, channel) and flatten
z
=
z
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
z_flattened
=
z
.
view
(
-
1
,
self
.
e_dim
)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d
=
(
torch
.
sum
(
z_flattened
**
2
,
dim
=
1
,
keepdim
=
True
)
+
torch
.
sum
(
self
.
embedding
.
weight
**
2
,
dim
=
1
)
-
2
*
torch
.
einsum
(
"bd,dn->bn"
,
z_flattened
,
self
.
embedding
.
weight
.
t
())
)
min_encoding_indices
=
torch
.
argmin
(
d
,
dim
=
1
)
z_q
=
self
.
embedding
(
min_encoding_indices
).
view
(
z
.
shape
)
perplexity
=
None
min_encodings
=
None
# compute loss for embedding
if
not
self
.
legacy
:
loss
=
self
.
beta
*
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
else
:
loss
=
torch
.
mean
((
z_q
.
detach
()
-
z
)
**
2
)
+
self
.
beta
*
torch
.
mean
((
z_q
-
z
.
detach
())
**
2
)
# preserve gradients
z_q
=
z
+
(
z_q
-
z
).
detach
()
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
if
self
.
remap
is
not
None
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z
.
shape
[
0
],
-
1
)
# add batch axis
min_encoding_indices
=
self
.
remap_to_used
(
min_encoding_indices
)
min_encoding_indices
=
min_encoding_indices
.
reshape
(
-
1
,
1
)
# flatten
if
self
.
sane_index_shape
:
min_encoding_indices
=
min_encoding_indices
.
reshape
(
z_q
.
shape
[
0
],
z_q
.
shape
[
2
],
z_q
.
shape
[
3
])
return
z_q
,
loss
,
(
perplexity
,
min_encodings
,
min_encoding_indices
)
def
get_codebook_entry
(
self
,
indices
,
shape
):
# shape specifying (batch, height, width, channel)
if
self
.
remap
is
not
None
:
indices
=
indices
.
reshape
(
shape
[
0
],
-
1
)
# add batch axis
indices
=
self
.
unmap_to_all
(
indices
)
indices
=
indices
.
reshape
(
-
1
)
# flatten again
# get quantized latent vectors
z_q
=
self
.
embedding
(
indices
)
if
shape
is
not
None
:
z_q
=
z_q
.
view
(
shape
)
# reshape back to match original input shape
z_q
=
z_q
.
permute
(
0
,
3
,
1
,
2
).
contiguous
()
return
z_q
class
VQModel
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register_to_config
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
class
LatentDiffusionUncondPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
unet
,
noise_scheduler
):
super
().
__init__
()
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
,
):
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
self
.
unet
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
num_trained_timesteps
=
self
.
noise_scheduler
.
config
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
torch
.
randn
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
image_size
,
self
.
unet
.
image_size
),
generator
=
generator
,
).
to
(
torch_device
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# 1. predict noise residual
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
pred_noise_t
=
self
.
unet
(
image
,
timesteps
)
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
# 3. optionally sample variance
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
# scale and decode image with vae
image
=
1
/
0.18215
*
image
image
=
self
.
vqvae
.
decode
(
image
)
image
=
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
return
image
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment