Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
8d649443
Commit
8d649443
authored
Jun 09, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
parents
999d3856
e3820fa3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
216 additions
and
0 deletions
+216
-0
models/vision/latent_diffusion/modeling_latent_diffusion.py
models/vision/latent_diffusion/modeling_latent_diffusion.py
+214
-0
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+2
-0
No files found.
models/vision/latent_diffusion/modeling_latent_diffusion.py
View file @
8d649443
...
...
@@ -2,9 +2,11 @@
import
math
import
numpy
as
np
import
tqdm
import
torch
import
torch.nn
as
nn
from
diffusers
import
DiffusionPipeline
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.modeling_utils
import
ModelMixin
...
...
@@ -719,3 +721,215 @@ class VQModel(ModelMixin, ConfigMixin):
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
class
DiagonalGaussianDistribution
(
object
):
def
__init__
(
self
,
parameters
,
deterministic
=
False
):
self
.
parameters
=
parameters
self
.
mean
,
self
.
logvar
=
torch
.
chunk
(
parameters
,
2
,
dim
=
1
)
self
.
logvar
=
torch
.
clamp
(
self
.
logvar
,
-
30.0
,
20.0
)
self
.
deterministic
=
deterministic
self
.
std
=
torch
.
exp
(
0.5
*
self
.
logvar
)
self
.
var
=
torch
.
exp
(
self
.
logvar
)
if
self
.
deterministic
:
self
.
var
=
self
.
std
=
torch
.
zeros_like
(
self
.
mean
).
to
(
device
=
self
.
parameters
.
device
)
def
sample
(
self
):
x
=
self
.
mean
+
self
.
std
*
torch
.
randn
(
self
.
mean
.
shape
).
to
(
device
=
self
.
parameters
.
device
)
return
x
def
kl
(
self
,
other
=
None
):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.
])
else
:
if
other
is
None
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
else
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
])
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.
])
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
def
mode
(
self
):
return
self
.
mean
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
ch
,
out_ch
,
num_res_blocks
,
attn_resolutions
,
in_channels
,
resolution
,
z_channels
,
n_embed
,
embed_dim
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
ch_mult
=
(
1
,
2
,
4
,
8
),
dropout
=
0.0
,
double_z
=
True
,
resamp_with_conv
=
True
,
give_pre_end
=
False
,
):
super
().
__init__
()
# register all __init__ params with self.register
self
.
register
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
n_embed
=
n_embed
,
embed_dim
=
embed_dim
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
double_z
=
double_z
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Encoder
self
.
encoder
=
Encoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
double_z
=
double_z
,
give_pre_end
=
give_pre_end
,
)
# pass init params to Decoder
self
.
decoder
=
Decoder
(
ch
=
ch
,
out_ch
=
out_ch
,
num_res_blocks
=
num_res_blocks
,
attn_resolutions
=
attn_resolutions
,
in_channels
=
in_channels
,
resolution
=
resolution
,
z_channels
=
z_channels
,
ch_mult
=
ch_mult
,
dropout
=
dropout
,
resamp_with_conv
=
resamp_with_conv
,
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
moments
=
self
.
quant_conv
(
h
)
posterior
=
DiagonalGaussianDistribution
(
moments
)
return
posterior
def
decode
(
self
,
z
):
z
=
self
.
post_quant_conv
(
z
)
dec
=
self
.
decoder
(
z
)
return
dec
def
forward
(
self
,
input
,
sample_posterior
=
True
):
posterior
=
self
.
encode
(
input
)
if
sample_posterior
:
z
=
posterior
.
sample
()
else
:
z
=
posterior
.
mode
()
dec
=
self
.
decode
(
z
)
return
dec
,
posterior
class
LatentDiffusion
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
noise_scheduler
):
super
().
__init__
()
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
def
__call__
(
self
,
prompt
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
self
.
unet
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
self
.
bert
.
to
(
torch_device
)
# get text embedding
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_embedding
=
self
.
bert
(
**
text_input
)[
0
]
num_trained_timesteps
=
self
.
noise_scheduler
.
num_timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
self
.
noise_scheduler
.
sample_noise
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
//
8
,
self
.
unet
.
resolution
//
8
),
device
=
torch_device
,
generator
=
generator
,
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# get actual t and t-1
train_step
=
inference_step_times
[
t
]
prev_train_step
=
inference_step_times
[
t
-
1
]
if
t
>
0
else
-
1
# compute alphas
alpha_prod_t
=
self
.
noise_scheduler
.
get_alpha_prod
(
train_step
)
alpha_prod_t_prev
=
self
.
noise_scheduler
.
get_alpha_prod
(
prev_train_step
)
alpha_prod_t_rsqrt
=
1
/
alpha_prod_t
.
sqrt
()
alpha_prod_t_prev_rsqrt
=
1
/
alpha_prod_t_prev
.
sqrt
()
beta_prod_t_sqrt
=
(
1
-
alpha_prod_t
).
sqrt
()
beta_prod_t_prev_sqrt
=
(
1
-
alpha_prod_t_prev
).
sqrt
()
# compute relevant coefficients
coeff_1
=
(
(
alpha_prod_t_prev
-
alpha_prod_t
).
sqrt
()
*
alpha_prod_t_prev_rsqrt
*
beta_prod_t_prev_sqrt
/
beta_prod_t_sqrt
*
eta
)
coeff_2
=
((
1
-
alpha_prod_t_prev
)
-
coeff_1
**
2
).
sqrt
()
# model forward
with
torch
.
no_grad
():
train_step
=
torch
.
tensor
([
train_step
]
*
image
.
shape
[
0
],
device
=
torch_device
)
noise_residual
=
self
.
unet
(
image
,
train_step
,
context
=
text_embedding
)
# predict mean of prev image
pred_mean
=
alpha_prod_t_rsqrt
*
(
image
-
beta_prod_t_sqrt
*
noise_residual
)
pred_mean
=
torch
.
clamp
(
pred_mean
,
-
1
,
1
)
pred_mean
=
(
1
/
alpha_prod_t_prev_rsqrt
)
*
pred_mean
+
coeff_2
*
noise_residual
# if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
if
eta
>
0.0
:
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
image
.
device
,
generator
=
generator
)
image
=
pred_mean
+
coeff_1
*
noise
else
:
image
=
pred_mean
image
=
1
/
image
image
=
self
.
vqvae
(
image
)
image
=
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
return
image
src/diffusers/models/unet_ldm.py
View file @
8d649443
...
...
@@ -1024,6 +1024,8 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self
.
num_classes
is
not
None
),
"must specify y if and only if the model is class-conditional"
hs
=
[]
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
x
.
device
)
t_emb
=
timestep_embedding
(
timesteps
,
self
.
model_channels
,
repeat_only
=
False
)
emb
=
self
.
time_embed
(
t_emb
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment