Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7ac909d6
Commit
7ac909d6
authored
Jun 10, 2022
by
patil-suraj
Browse files
make ldm work, add classifier free guidence
parent
9a1a6e97
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
76 additions
and
44 deletions
+76
-44
models/vision/latent_diffusion/modeling_latent_diffusion.py
models/vision/latent_diffusion/modeling_latent_diffusion.py
+76
-44
No files found.
models/vision/latent_diffusion/modeling_latent_diffusion.py
View file @
7ac909d6
...
@@ -2,10 +2,10 @@
...
@@ -2,10 +2,10 @@
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
tqdm
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
tqdm
from
diffusers
import
DiffusionPipeline
from
diffusers
import
DiffusionPipeline
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.modeling_utils
import
ModelMixin
from
diffusers.modeling_utils
import
ModelMixin
...
@@ -740,30 +740,29 @@ class DiagonalGaussianDistribution(object):
...
@@ -740,30 +740,29 @@ class DiagonalGaussianDistribution(object):
def
kl
(
self
,
other
=
None
):
def
kl
(
self
,
other
=
None
):
if
self
.
deterministic
:
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.
0
])
return
torch
.
Tensor
([
0.
])
else
:
else
:
if
other
is
None
:
if
other
is
None
:
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
,
2
)
+
self
.
var
-
1.0
-
self
.
logvar
,
dim
=
[
1
,
2
,
3
])
else
:
else
:
return
0.5
*
torch
.
sum
(
return
0.5
*
torch
.
sum
(
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
torch
.
pow
(
self
.
mean
-
other
.
mean
,
2
)
/
other
.
var
+
self
.
var
/
other
.
var
+
self
.
var
/
other
.
var
-
1.0
-
self
.
logvar
+
other
.
logvar
,
-
1.0
dim
=
[
1
,
2
,
3
])
-
self
.
logvar
+
other
.
logvar
,
dim
=
[
1
,
2
,
3
],
)
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
def
nll
(
self
,
sample
,
dims
=
[
1
,
2
,
3
]):
if
self
.
deterministic
:
if
self
.
deterministic
:
return
torch
.
Tensor
([
0.
0
])
return
torch
.
Tensor
([
0.
])
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
logtwopi
=
np
.
log
(
2.0
*
np
.
pi
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
return
0.5
*
torch
.
sum
(
logtwopi
+
self
.
logvar
+
torch
.
pow
(
sample
-
self
.
mean
,
2
)
/
self
.
var
,
dim
=
dims
)
def
mode
(
self
):
def
mode
(
self
):
return
self
.
mean
return
self
.
mean
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
class
AutoencoderKL
(
ModelMixin
,
ConfigMixin
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -835,7 +834,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
...
@@ -835,7 +834,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
give_pre_end
=
give_pre_end
,
give_pre_end
=
give_pre_end
,
)
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
2
*
z_channels
,
2
*
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
z_channels
,
1
)
def
encode
(
self
,
x
):
def
encode
(
self
,
x
):
...
@@ -864,7 +863,7 @@ class LatentDiffusion(DiffusionPipeline):
...
@@ -864,7 +863,7 @@ class LatentDiffusion(DiffusionPipeline):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
def
__call__
(
self
,
prompt
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
def
__call__
(
self
,
prompt
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
guidance_scale
=
1.0
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
if
torch_device
is
None
:
...
@@ -874,6 +873,10 @@ class LatentDiffusion(DiffusionPipeline):
...
@@ -874,6 +873,10 @@ class LatentDiffusion(DiffusionPipeline):
self
.
vqvae
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
self
.
bert
.
to
(
torch_device
)
self
.
bert
.
to
(
torch_device
)
if
guidance_scale
!=
1.0
:
uncond_input
=
self
.
tokenizer
([
""
],
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
uncond_embeddings
=
self
.
bert
(
uncond_input
.
input_ids
)[
0
]
# get text embedding
# get text embedding
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)[
0
]
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)[
0
]
...
@@ -886,45 +889,74 @@ class LatentDiffusion(DiffusionPipeline):
...
@@ -886,45 +889,74 @@ class LatentDiffusion(DiffusionPipeline):
device
=
torch_device
,
device
=
torch_device
,
generator
=
generator
,
generator
=
generator
,
)
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# get actual t and t-1
# 1. predict noise residual
if
guidance_scale
==
1.0
:
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
context
=
text_embedding
image_in
=
image
else
:
image_in
=
torch
.
cat
([
image
]
*
2
)
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
context
=
torch
.
cat
([
uncond_embeddings
,
text_embedding
])
with
torch
.
no_grad
():
pred_noise_t
=
self
.
unet
(
image_in
,
timesteps
,
context
=
context
)
if
guidance_scale
!=
1.0
:
pred_noise_t_uncond
,
pred_noise_t
=
pred_noise_t
.
chunk
(
2
)
pred_noise_t
=
pred_noise_t_uncond
+
guidance_scale
*
(
pred_noise_t
-
pred_noise_t_uncond
)
# 2. get actual t and t-1
train_step
=
inference_step_times
[
t
]
train_step
=
inference_step_times
[
t
]
prev_train_step
=
inference_step_times
[
t
-
1
]
if
t
>
0
else
-
1
prev_train_step
=
inference_step_times
[
t
-
1
]
if
t
>
0
else
-
1
# compute alphas
#
3.
compute alphas
, betas
alpha_prod_t
=
self
.
noise_scheduler
.
get_alpha_prod
(
train_step
)
alpha_prod_t
=
self
.
noise_scheduler
.
get_alpha_prod
(
train_step
)
alpha_prod_t_prev
=
self
.
noise_scheduler
.
get_alpha_prod
(
prev_train_step
)
alpha_prod_t_prev
=
self
.
noise_scheduler
.
get_alpha_prod
(
prev_train_step
)
alpha_prod_t_rsqrt
=
1
/
alpha_prod_t
.
sqrt
()
beta_prod_t
=
1
-
alpha_prod_t
alpha_prod_t_prev_rsqrt
=
1
/
alpha_prod_t_prev
.
sqrt
()
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
beta_prod_t_sqrt
=
(
1
-
alpha_prod_t
).
sqrt
()
beta_prod_t_prev_sqrt
=
(
1
-
alpha_prod_t_prev
).
sqrt
()
# compute relevant coefficients
coeff_1
=
(
(
alpha_prod_t_prev
-
alpha_prod_t
).
sqrt
()
*
alpha_prod_t_prev_rsqrt
*
beta_prod_t_prev_sqrt
/
beta_prod_t_sqrt
*
eta
)
coeff_2
=
((
1
-
alpha_prod_t_prev
)
-
coeff_1
**
2
).
sqrt
()
#
model forward
#
4. Compute predicted previous image from predicted noise
with
torch
.
no_grad
():
# First: compute predicted original image from predicted noise also called
train_step
=
torch
.
tensor
([
train_step
]
*
image
.
shape
[
0
],
device
=
torch_device
)
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
noise_residual
=
self
.
unet
(
image
,
train_step
,
context
=
text_embedding
)
pred_original_image
=
(
image
-
beta_prod_t
.
sqrt
()
*
pred_noise_t
)
/
alpha_prod_t
.
sqrt
(
)
# predict mean of prev image
# Second: Clip "predicted x_0"
pred_mean
=
alpha_prod_t_rsqrt
*
(
image
-
beta_prod_t_sqrt
*
noise_residual
)
# pred_original_image = torch.clamp(pred_original_image, -1, 1)
pred_mean
=
torch
.
clamp
(
pred_mean
,
-
1
,
1
)
pred_mean
=
(
1
/
alpha_prod_t_prev_rsqrt
)
*
pred_mean
+
coeff_2
*
noise_residual
# if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
# Third: Compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
std_dev_t
=
(
beta_prod_t_prev
/
beta_prod_t
).
sqrt
()
*
(
1
-
alpha_prod_t
/
alpha_prod_t_prev
).
sqrt
()
std_dev_t
=
eta
*
std_dev_t
# Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
).
sqrt
()
*
pred_noise_t
# Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_image
=
alpha_prod_t_prev
.
sqrt
()
*
pred_original_image
+
pred_image_direction
# 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
# Note: eta = 1.0 essentially corresponds to DDPM
if
eta
>
0.0
:
if
eta
>
0.0
:
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
image
.
device
,
generator
=
generator
)
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
image
.
device
,
generator
=
generator
)
image
=
pred_
mean
+
coeff_1
*
noise
prev_
image
=
pred_
prev_image
+
std_dev_t
*
noise
else
:
else
:
image
=
pred_mean
prev_image
=
pred_prev_image
# 6. Set current image to prev_image: x_t -> x_t-1
image
=
prev_image
image
=
1
/
0.18215
*
image
image
=
1
/
0.18215
*
image
image
=
self
.
vqvae
.
decode
(
image
)
image
=
self
.
vqvae
.
decode
(
image
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment