Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7ac909d6
Commit
7ac909d6
authored
Jun 10, 2022
by
patil-suraj
Browse files
make ldm work, add classifier-free guidance
parent
9a1a6e97
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
76 additions
and
44 deletions
+76
-44
models/vision/latent_diffusion/modeling_latent_diffusion.py
models/vision/latent_diffusion/modeling_latent_diffusion.py
+76
-44
No files found.
models/vision/latent_diffusion/modeling_latent_diffusion.py
View file @
7ac909d6
...
...
@@ -2,10 +2,10 @@
import
math
import
numpy
as
np
import
tqdm
import
torch
import
torch.nn
as
nn
import
tqdm
from
diffusers
import
DiffusionPipeline
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.modeling_utils
import
ModelMixin
...
...
@@ -740,30 +740,29 @@ class DiagonalGaussianDistribution(object):
def
kl
(
self
,
other
=
None
):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.
0
])
return
torch
.
Tensor
([
0.
])
else
:
if
other
is
None
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
else
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
],
)
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
])
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.
0
])
return
torch
.
Tensor
([
0.
])
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
def
mode
(
self
):
return
self
.
mean
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
self
,
...
...
@@ -835,7 +834,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
give_pre_end
=
give_pre_end
,
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
def
encode
(
self
,
x
):
...
...
@@ -864,7 +863,7 @@ class LatentDiffusion(DiffusionPipeline):
super
().
__init__
()
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
def
__call__
(
self
,
prompt
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
def
__call__
(
self
,
prompt
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
guidance_scale
=
1.0
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
...
...
@@ -873,7 +872,11 @@ class LatentDiffusion(DiffusionPipeline):
self
.
unet
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
self
.
bert
.
to
(
torch_device
)
if
guidance_scale
!=
1.0
:
uncond_input
=
self
.
tokenizer
([
""
],
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
uncond_embeddings
=
self
.
bert
(
uncond_input
.
input_ids
)[
0
]
# get text embedding
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)[
0
]
...
...
@@ -886,46 +889,75 @@ class LatentDiffusion(DiffusionPipeline):
device
=
torch_device
,
generator
=
generator
,
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper for an in-depth understanding
# Notation (<variable name> -> <name in paper>):
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointing to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# get actual t and t-1
# 1. predict noise residual
if
guidance_scale
==
1.0
:
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
context
=
text_embedding
image_in
=
image
else
:
image_in
=
torch
.
cat
([
image
]
*
2
)
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
context
=
torch
.
cat
([
uncond_embeddings
,
text_embedding
])
with
torch
.
no_grad
():
pred_noise_t
=
self
.
unet
(
image_in
,
timesteps
,
context
=
context
)
if
guidance_scale
!=
1.0
:
pred_noise_t_uncond
,
pred_noise_t
=
pred_noise_t
.
chunk
(
2
)
pred_noise_t
=
pred_noise_t_uncond
+
guidance_scale
*
(
pred_noise_t
-
pred_noise_t_uncond
)
# 2. get actual t and t-1
train_step
=
inference_step_times
[
t
]
prev_train_step
=
inference_step_times
[
t
-
1
]
if
t
>
0
else
-
1
# compute alphas
#
3.
compute alphas
, betas
alpha_prod_t
=
self
.
noise_scheduler
.
get_alpha_prod
(
train_step
)
alpha_prod_t_prev
=
self
.
noise_scheduler
.
get_alpha_prod
(
prev_train_step
)
alpha_prod_t_rsqrt
=
1
/
alpha_prod_t
.
sqrt
()
alpha_prod_t_prev_rsqrt
=
1
/
alpha_prod_t_prev
.
sqrt
()
beta_prod_t_sqrt
=
(
1
-
alpha_prod_t
).
sqrt
()
beta_prod_t_prev_sqrt
=
(
1
-
alpha_prod_t_prev
).
sqrt
()
# compute relevant coefficients
coeff_1
=
(
(
alpha_prod_t_prev
-
alpha_prod_t
).
sqrt
()
*
alpha_prod_t_prev_rsqrt
*
beta_prod_t_prev_sqrt
/
beta_prod_t_sqrt
*
eta
)
coeff_2
=
((
1
-
alpha_prod_t_prev
)
-
coeff_1
**
2
).
sqrt
()
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
# model forward
with
torch
.
no_grad
():
train_step
=
torch
.
tensor
([
train_step
]
*
image
.
shape
[
0
],
device
=
torch_device
)
noise_residual
=
self
.
unet
(
image
,
train_step
,
context
=
text_embedding
)
# 4. Compute predicted previous image from predicted noise
# First: compute predicted original image from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_image
=
(
image
-
beta_prod_t
.
sqrt
()
*
pred_noise_t
)
/
alpha_prod_t
.
sqrt
()
# Second: Clip "predicted x_0"
# pred_original_image = torch.clamp(pred_original_image, -1, 1)
# Third: Compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
std_dev_t
=
(
beta_prod_t_prev
/
beta_prod_t
).
sqrt
()
*
(
1
-
alpha_prod_t
/
alpha_prod_t_prev
).
sqrt
()
std_dev_t
=
eta
*
std_dev_t
# predict mean of prev image
pred_mean
=
alpha_prod_t_rsqrt
*
(
image
-
beta_prod_t_sqrt
*
noise_residual
)
pred_mean
=
torch
.
clamp
(
pred_mean
,
-
1
,
1
)
pred_mean
=
(
1
/
alpha_prod_t_prev_rsqrt
)
*
pred_mean
+
coeff_2
*
noise_residual
# Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
).
sqrt
()
*
pred_noise_t
# if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
# Fifth: Compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_image
=
alpha_prod_t_prev
.
sqrt
()
*
pred_original_image
+
pred_image_direction
# 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
# Note: eta = 1.0 essentially corresponds to DDPM
if
eta
>
0.0
:
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
image
.
device
,
generator
=
generator
)
image
=
pred_
mean
+
coeff_1
*
noise
prev_
image
=
pred_
prev_image
+
std_dev_t
*
noise
else
:
image
=
pred_mean
prev_image
=
pred_prev_image
# 6. Set current image to prev_image: x_t -> x_t-1
image
=
prev_image
image
=
1
/
0.18215
*
image
image
=
self
.
vqvae
.
decode
(
image
)
image
=
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment