renzhc / diffusers_dcu · Commits

Commit 4e3f4a9e authored Jun 10, 2022 by patil-suraj

cleanup LDM

parent a7584637
Showing 1 changed file with 17 additions and 14 deletions.

models/vision/latent_diffusion/modeling_latent_diffusion.py (+17 −14)
@@ -863,6 +863,7 @@ class LatentDiffusion(DiffusionPipeline):
         super().__init__()
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

+    @torch.no_grad()
     def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
         # eta corresponds to η in paper and should be between [0, 1]
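
A note on the added decorator: `torch.no_grad()` used as a decorator disables gradient tracking for the entire `__call__`, which is what allows a later hunk to drop the inner `with torch.no_grad():` block. A minimal sketch of the equivalence (the `sample_step` function and toy model below are hypothetical, not part of the pipeline):

import torch

# toy stand-in for the pipeline's unet
model = torch.nn.Linear(4, 4)

@torch.no_grad()
def sample_step(x):
    # everything inside runs with autograd disabled,
    # so no explicit with-block is needed around the forward pass
    return model(x)

x = torch.randn(2, 4)
out = sample_step(x)
assert not out.requires_grad  # grad tracking was off for the whole call

Disabling autograd for the full sampling loop also avoids retaining intermediate activations, which matters over a 50-step denoising loop.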
@@ -873,6 +874,7 @@ class LatentDiffusion(DiffusionPipeline):
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

+        # get unconditional embeddings for classifier-free guidance
         if guidance_scale != 1.0:
             uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]
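
For context, here is a standalone sketch of how the prompt embedding and the empty-prompt (unconditional) embedding are produced side by side. `AutoTokenizer`/`AutoModel` from `transformers` are stand-ins for the pipeline's `tokenizer` and `bert`; the only assumption is that the encoder returns a `(batch, 77, hidden)` hidden-state tensor as in the diff:

import torch
from transformers import AutoTokenizer, AutoModel

# stand-ins for the pipeline's `tokenizer` and `bert`
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = AutoModel.from_pretrained("bert-base-uncased")

prompt = "a painting of a squirrel eating a burger"
with torch.no_grad():
    # embedding for the actual prompt
    text_input = tokenizer([prompt], padding="max_length", max_length=77,
                           truncation=True, return_tensors="pt")
    text_embedding = encoder(text_input.input_ids)[0]
    # embedding for the empty prompt, used as the unconditional branch
    uncond_input = tokenizer([""], padding="max_length", max_length=77,
                             return_tensors="pt")
    uncond_embeddings = encoder(uncond_input.input_ids)[0]

print(text_embedding.shape, uncond_embeddings.shape)  # both (1, 77, hidden)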
@@ -901,19 +903,23 @@ class LatentDiffusion(DiffusionPipeline):
         # - pred_image_direction -> "direction pointing to x_t"
         # - pred_prev_image -> "x_t-1"
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # 1. predict noise residual
+            # guidance_scale of 1 means no guidance
             if guidance_scale == 1.0:
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-                context = text_embedding
-                image_in = image
+                image_in = image
+                context = text_embedding
+                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
             else:
+                # for classifier-free guidance, we need two forward passes
+                # here we concatenate the conditional and unconditional embeddings into a single batch
+                # to avoid doing two forward passes
                 image_in = torch.cat([image] * 2)
-                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                 context = torch.cat([uncond_embeddings, text_embedding])
+                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)

-            with torch.no_grad():
-                pred_noise_t = self.unet(image_in, timesteps, context=context)
+            # 1. predict noise residual
+            pred_noise_t = self.unet(image_in, timesteps, context=context)

+            # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
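
The two `pred_noise_t` lines at the end of this hunk implement classifier-free guidance, eps_hat = eps_uncond + s * (eps_cond - eps_uncond). A self-contained sketch of just that combine step, with made-up tensor shapes:

import torch

guidance_scale = 7.5
# batched prediction from one forward pass over [uncond, cond] inputs
pred_noise = torch.randn(2, 3, 32, 32)

# split the batched prediction back into its two halves
pred_noise_uncond, pred_noise_cond = pred_noise.chunk(2)

# eps_hat = eps_uncond + s * (eps_cond - eps_uncond)
pred_noise_guided = pred_noise_uncond + guidance_scale * (pred_noise_cond - pred_noise_uncond)
assert pred_noise_guided.shape == pred_noise_cond.shape

With s = 1 the expression collapses to eps_cond, which is why the pipeline skips the doubled batch entirely when guidance_scale == 1.0.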
@@ -933,18 +939,15 @@ class LatentDiffusion(DiffusionPipeline):
             # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
             pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()

-            # Second: Clip "predicted x_0"
-            # pred_original_image = torch.clamp(pred_original_image, -1, 1)
-
-            # Third: Compute variance: "sigma_t(η)" -> see formula (16)
+            # Second: Compute variance: "sigma_t(η)" -> see formula (16)
             # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
             std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
             std_dev_t = eta * std_dev_t

-            # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            # Third: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
             pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2).sqrt() * pred_noise_t

-            # Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            # Fourth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
             pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction

             # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
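
The renumbered comments track the DDIM update of formulas (12) and (16) from https://arxiv.org/pdf/2010.02502.pdf. A hedged, self-contained sketch of the same arithmetic (the function name and signature are illustrative, not the pipeline's API):

import torch

def ddim_step(image, pred_noise_t, alpha_prod_t, alpha_prod_t_prev, eta=0.0):
    """One DDIM update following eqs. (12) and (16); a sketch, not this file's code."""
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # "predicted x_0" (eq. 12)
    pred_original = (image - beta_prod_t ** 0.5 * pred_noise_t) / alpha_prod_t ** 0.5

    # sigma_t(eta) (eq. 16); eta = 0 gives fully deterministic sampling
    std_dev_t = eta * ((beta_prod_t_prev / beta_prod_t) ** 0.5
                       * (1 - alpha_prod_t / alpha_prod_t_prev) ** 0.5)

    # "direction pointing to x_t" (eq. 12)
    direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * pred_noise_t

    prev_image = alpha_prod_t_prev ** 0.5 * pred_original + direction
    if eta > 0:
        # optional stochastic term, as in step 5 of the loop above
        prev_image = prev_image + std_dev_t * torch.randn_like(image)
    return prev_image

# usage with made-up cumulative alphas (alpha_prod grows toward t=0)
x = torch.randn(1, 3, 32, 32)
x_prev = ddim_step(x, torch.randn_like(x), alpha_prod_t=0.90, alpha_prod_t_prev=0.95)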
@@ -958,9 +961,9 @@ class LatentDiffusion(DiffusionPipeline):
             # 6. Set current image to prev_image: x_t -> x_t-1
             image = prev_image

+        # scale and decode image with vae
         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)

         image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
-        image = 255. * image
         return image
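
Since `image = 255. * image` was dropped, `__call__` now returns float images in [0, 1]. A hypothetical caller-side conversion to PIL (the `to_pil` helper is an assumption, not part of this file; it assumes the NCHW float output produced above):

import numpy as np
from PIL import Image

def to_pil(image_tensor):
    # (N, C, H, W) floats in [0, 1] -> list of PIL images
    array = (image_tensor.cpu().permute(0, 2, 3, 1).numpy() * 255).round().astype("uint8")
    return [Image.fromarray(arr) for arr in array]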