Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
f28cb9e1
Commit
f28cb9e1
authored
Jun 13, 2022
by
Patrick von Platen
Browse files
add dummy code for pmls
parent
1f66160e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
263 additions
and
121 deletions
+263
-121
src/diffusers/schedulers/scheduling_plms.py
src/diffusers/schedulers/scheduling_plms.py
+263
-121
No files found.
src/diffusers/schedulers/scheduling_plms.py
View file @
f28cb9e1
...
@@ -14,131 +14,273 @@
...
@@ -14,131 +14,273 @@
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
torch
from
tqdm
import
tqdm
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.schedulers_utils
import
SchedulerMixin
,
betas_for_alpha_bar
,
linear_beta_schedule
from
.schedulers_utils
import
SchedulerMixin
,
betas_for_alpha_bar
,
linear_beta_schedule
class
DDIMScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
noise_like
(
shape
,
device
,
repeat
=
False
):
def
__init__
(
repeat_noise
=
lambda
:
torch
.
randn
((
1
,
*
shape
[
1
:]),
device
=
device
).
repeat
(
shape
[
0
],
*
((
1
,)
*
(
len
(
shape
)
-
1
)))
self
,
noise
=
lambda
:
torch
.
randn
(
shape
,
device
=
device
)
timesteps
=
1000
,
return
repeat_noise
()
if
repeat
else
noise
()
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
def
make_ddim_timesteps
(
ddim_discr_method
,
num_ddim_timesteps
,
num_ddpm_timesteps
,
verbose
=
True
):
clip_predicted_image
=
True
,
if
ddim_discr_method
==
'uniform'
:
tensor_format
=
"np"
,
c
=
num_ddpm_timesteps
//
num_ddim_timesteps
):
ddim_timesteps
=
np
.
asarray
(
list
(
range
(
0
,
num_ddpm_timesteps
,
c
)))
elif
ddim_discr_method
==
'quad'
:
ddim_timesteps
=
((
np
.
linspace
(
0
,
np
.
sqrt
(
num_ddpm_timesteps
*
.
8
),
num_ddim_timesteps
))
**
2
).
astype
(
int
)
else
:
raise
NotImplementedError
(
f
'There is no ddim discretization method called "
{
ddim_discr_method
}
"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out
=
ddim_timesteps
+
1
if
verbose
:
print
(
f
'Selected timesteps for ddim sampler:
{
steps_out
}
'
)
return
steps_out
def
make_ddim_sampling_parameters
(
alphacums
,
ddim_timesteps
,
eta
,
verbose
=
True
):
# select alphas for computing the variance schedule
alphas
=
alphacums
[
ddim_timesteps
]
alphas_prev
=
np
.
asarray
([
alphacums
[
0
]]
+
alphacums
[
ddim_timesteps
[:
-
1
]].
tolist
())
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas
=
eta
*
np
.
sqrt
((
1
-
alphas_prev
)
/
(
1
-
alphas
)
*
(
1
-
alphas
/
alphas_prev
))
if
verbose
:
print
(
f
'Selected alphas for ddim sampler: a_t:
{
alphas
}
; a_(t-1):
{
alphas_prev
}
'
)
print
(
f
'For the chosen value of eta, which is
{
eta
}
, '
f
'this results in the following sigma_t schedule for ddim sampler
{
sigmas
}
'
)
return
sigmas
,
alphas
,
alphas_prev
class
PLMSSampler
(
object
):
def
__init__
(
self
,
model
,
schedule
=
"linear"
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
model
=
model
timesteps
=
timesteps
,
self
.
ddpm_num_timesteps
=
model
.
num_timesteps
beta_start
=
beta_start
,
self
.
schedule
=
schedule
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
def
register_buffer
(
self
,
name
,
attr
):
)
if
type
(
attr
)
==
torch
.
Tensor
:
self
.
timesteps
=
int
(
timesteps
)
if
attr
.
device
!=
torch
.
device
(
"cuda"
):
self
.
clip_image
=
clip_predicted_image
attr
=
attr
.
to
(
torch
.
device
(
"cuda"
))
setattr
(
self
,
name
,
attr
)
if
beta_schedule
==
"linear"
:
self
.
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
make_schedule
(
self
,
ddim_num_steps
,
ddim_discretize
=
"uniform"
,
ddim_eta
=
0.
,
verbose
=
True
):
elif
beta_schedule
==
"squaredcos_cap_v2"
:
if
ddim_eta
!=
0
:
# GLIDE cosine schedule
raise
ValueError
(
'ddim_eta must be 0 for PLMS'
)
self
.
betas
=
betas_for_alpha_bar
(
self
.
ddim_timesteps
=
make_ddim_timesteps
(
ddim_discr_method
=
ddim_discretize
,
num_ddim_timesteps
=
ddim_num_steps
,
timesteps
,
num_ddpm_timesteps
=
self
.
ddpm_num_timesteps
,
verbose
=
verbose
)
lambda
t
:
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
,
alphas_cumprod
=
self
.
model
.
alphas_cumprod
)
assert
alphas_cumprod
.
shape
[
0
]
==
self
.
ddpm_num_timesteps
,
'alphas have to be defined for each timestep'
to_torch
=
lambda
x
:
x
.
clone
().
detach
().
to
(
torch
.
float32
).
to
(
self
.
model
.
device
)
self
.
register_buffer
(
'betas'
,
to_torch
(
self
.
model
.
betas
))
self
.
register_buffer
(
'alphas_cumprod'
,
to_torch
(
alphas_cumprod
))
self
.
register_buffer
(
'alphas_cumprod_prev'
,
to_torch
(
self
.
model
.
alphas_cumprod_prev
))
# calculations for diffusion q(x_t | x_{t-1}) and others
self
.
register_buffer
(
'sqrt_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'sqrt_one_minus_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
-
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'log_one_minus_alphas_cumprod'
,
to_torch
(
np
.
log
(
1.
-
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'sqrt_recip_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'sqrt_recipm1_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
.
cpu
()
-
1
)))
# ddim sampling parameters
ddim_sigmas
,
ddim_alphas
,
ddim_alphas_prev
=
make_ddim_sampling_parameters
(
alphacums
=
alphas_cumprod
.
cpu
(),
ddim_timesteps
=
self
.
ddim_timesteps
,
eta
=
ddim_eta
,
verbose
=
verbose
)
self
.
register_buffer
(
'ddim_sigmas'
,
ddim_sigmas
)
self
.
register_buffer
(
'ddim_alphas'
,
ddim_alphas
)
self
.
register_buffer
(
'ddim_alphas_prev'
,
ddim_alphas_prev
)
self
.
register_buffer
(
'ddim_sqrt_one_minus_alphas'
,
np
.
sqrt
(
1.
-
ddim_alphas
))
sigmas_for_original_sampling_steps
=
ddim_eta
*
torch
.
sqrt
(
(
1
-
self
.
alphas_cumprod_prev
)
/
(
1
-
self
.
alphas_cumprod
)
*
(
1
-
self
.
alphas_cumprod
/
self
.
alphas_cumprod_prev
))
self
.
register_buffer
(
'ddim_sigmas_for_original_num_steps'
,
sigmas_for_original_sampling_steps
)
@
torch
.
no_grad
()
def
sample
(
self
,
S
,
batch_size
,
shape
,
conditioning
=
None
,
callback
=
None
,
normals_sequence
=
None
,
img_callback
=
None
,
quantize_x0
=
False
,
eta
=
0.
,
mask
=
None
,
x0
=
None
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
verbose
=
True
,
x_T
=
None
,
log_every_t
=
100
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**
kwargs
):
if
conditioning
is
not
None
:
if
isinstance
(
conditioning
,
dict
):
cbs
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]].
shape
[
0
]
if
cbs
!=
batch_size
:
print
(
f
"Warning: Got
{
cbs
}
conditionings but batch-size is
{
batch_size
}
"
)
else
:
if
conditioning
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got
{
conditioning
.
shape
[
0
]
}
conditionings but batch-size is
{
batch_size
}
"
)
self
.
make_schedule
(
ddim_num_steps
=
S
,
ddim_eta
=
eta
,
verbose
=
verbose
)
# sampling
C
,
H
,
W
=
shape
size
=
(
batch_size
,
C
,
H
,
W
)
print
(
f
'Data shape for PLMS sampling is
{
size
}
'
)
samples
,
intermediates
=
self
.
plms_sampling
(
conditioning
,
size
,
callback
=
callback
,
img_callback
=
img_callback
,
quantize_denoised
=
quantize_x0
,
mask
=
mask
,
x0
=
x0
,
ddim_use_original_steps
=
False
,
noise_dropout
=
noise_dropout
,
temperature
=
temperature
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
x_T
=
x_T
,
log_every_t
=
log_every_t
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
)
return
samples
,
intermediates
@
torch
.
no_grad
()
def
plms_sampling
(
self
,
cond
,
shape
,
x_T
=
None
,
ddim_use_original_steps
=
False
,
callback
=
None
,
timesteps
=
None
,
quantize_denoised
=
False
,
mask
=
None
,
x0
=
None
,
img_callback
=
None
,
log_every_t
=
100
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,):
device
=
self
.
model
.
betas
.
device
b
=
shape
[
0
]
if
x_T
is
None
:
img
=
torch
.
randn
(
shape
,
device
=
device
)
else
:
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
img
=
x_T
self
.
alphas
=
1.0
-
self
.
betas
if
timesteps
is
None
:
self
.
alphas_cumprod
=
np
.
cumprod
(
self
.
alphas
,
axis
=
0
)
timesteps
=
self
.
ddpm_num_timesteps
if
ddim_use_original_steps
else
self
.
ddim_timesteps
self
.
one
=
np
.
array
(
1.0
)
elif
timesteps
is
not
None
and
not
ddim_use_original_steps
:
subset_end
=
int
(
min
(
timesteps
/
self
.
ddim_timesteps
.
shape
[
0
],
1
)
*
self
.
ddim_timesteps
.
shape
[
0
])
-
1
self
.
set_format
(
tensor_format
=
tensor_format
)
timesteps
=
self
.
ddim_timesteps
[:
subset_end
]
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
intermediates
=
{
'x_inter'
:
[
img
],
'pred_x0'
:
[
img
]}
# TODO(PVP) - check how much of these is actually necessary!
time_range
=
list
(
reversed
(
range
(
0
,
timesteps
)))
if
ddim_use_original_steps
else
np
.
flip
(
timesteps
)
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
total_steps
=
timesteps
if
ddim_use_original_steps
else
timesteps
.
shape
[
0
]
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
print
(
f
"Running PLMS Sampling with
{
total_steps
}
timesteps"
)
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
iterator
=
tqdm
(
time_range
,
desc
=
'PLMS Sampler'
,
total
=
total_steps
)
# log_variance = torch.log(variance.clamp(min=1e-20))
old_eps
=
[]
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
for
i
,
step
in
enumerate
(
iterator
):
#
index
=
total_steps
-
i
-
1
#
ts
=
torch
.
full
((
b
,),
step
,
device
=
device
,
dtype
=
torch
.
long
)
# self.register_buffer("log_variance", log_variance.to(torch.float32))
ts_next
=
torch
.
full
((
b
,),
time_range
[
min
(
i
+
1
,
len
(
time_range
)
-
1
)],
device
=
device
,
dtype
=
torch
.
long
)
def
get_alpha
(
self
,
time_step
):
if
mask
is
not
None
:
return
self
.
alphas
[
time_step
]
assert
x0
is
not
None
img_orig
=
self
.
model
.
q_sample
(
x0
,
ts
)
# TODO: deterministic forward pass?
def
get_beta
(
self
,
time_step
):
img
=
img_orig
*
mask
+
(
1.
-
mask
)
*
img
return
self
.
betas
[
time_step
]
outs
=
self
.
p_sample_plms
(
img
,
cond
,
ts
,
index
=
index
,
use_original_steps
=
ddim_use_original_steps
,
def
get_alpha_prod
(
self
,
time_step
):
quantize_denoised
=
quantize_denoised
,
temperature
=
temperature
,
if
time_step
<
0
:
noise_dropout
=
noise_dropout
,
score_corrector
=
score_corrector
,
return
self
.
one
corrector_kwargs
=
corrector_kwargs
,
return
self
.
alphas_cumprod
[
time_step
]
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
def
get_orig_t
(
self
,
t
,
num_inference_steps
):
old_eps
=
old_eps
,
t_next
=
ts_next
)
if
t
<
0
:
img
,
pred_x0
,
e_t
=
outs
return
-
1
old_eps
.
append
(
e_t
)
return
self
.
timesteps
//
num_inference_steps
*
t
if
len
(
old_eps
)
>=
4
:
old_eps
.
pop
(
0
)
def
get_variance
(
self
,
t
,
num_inference_steps
):
if
callback
:
callback
(
i
)
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
if
img_callback
:
img_callback
(
pred_x0
,
i
)
orig_prev_t
=
self
.
get_orig_t
(
t
-
1
,
num_inference_steps
)
if
index
%
log_every_t
==
0
or
index
==
total_steps
-
1
:
alpha_prod_t
=
self
.
get_alpha_prod
(
orig_t
)
intermediates
[
'x_inter'
].
append
(
img
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
orig_prev_t
)
intermediates
[
'pred_x0'
].
append
(
pred_x0
)
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
return
img
,
intermediates
variance
=
(
beta_prod_t_prev
/
beta_prod_t
)
*
(
1
-
alpha_prod_t
/
alpha_prod_t_prev
)
@
torch
.
no_grad
()
def
p_sample_plms
(
self
,
x
,
c
,
t
,
index
,
repeat_noise
=
False
,
use_original_steps
=
False
,
quantize_denoised
=
False
,
return
variance
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
old_eps
=
None
,
t_next
=
None
):
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
,
eta
):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
def
get_model_output
(
x
,
t
):
if
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
# Notation (<variable name> -> <name in paper>
e_t
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
# - pred_noise_t -> e_theta(x_t, t)
else
:
# - pred_original_image -> f_theta(x_t, t) or x_0
x_in
=
torch
.
cat
([
x
]
*
2
)
# - std_dev_t -> sigma_t
t_in
=
torch
.
cat
([
t
]
*
2
)
# - eta -> η
c_in
=
torch
.
cat
([
unconditional_conditioning
,
c
])
# - pred_image_direction -> "direction pointingc to x_t"
e_t_uncond
,
e_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
).
chunk
(
2
)
# - pred_prev_image -> "x_t-1"
e_t
=
e_t_uncond
+
unconditional_guidance_scale
*
(
e_t
-
e_t_uncond
)
# 1. get actual t and t-1
if
score_corrector
is
not
None
:
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
assert
self
.
model
.
parameterization
==
"eps"
orig_prev_t
=
self
.
get_orig_t
(
t
-
1
,
num_inference_steps
)
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
# 2. compute alphas, betas
return
e_t
alpha_prod_t
=
self
.
get_alpha_prod
(
orig_t
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
orig_prev_t
)
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
beta_prod_t
=
1
-
alpha_prod_t
alphas_prev
=
self
.
model
.
alphas_cumprod_prev
if
use_original_steps
else
self
.
ddim_alphas_prev
sqrt_one_minus_alphas
=
self
.
model
.
sqrt_one_minus_alphas_cumprod
if
use_original_steps
else
self
.
ddim_sqrt_one_minus_alphas
# 3. compute predicted original image from predicted noise also called
sigmas
=
self
.
model
.
ddim_sigmas_for_original_num_steps
if
use_original_steps
else
self
.
ddim_sigmas
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_image
=
(
image
-
beta_prod_t
**
(
0.5
)
*
residual
)
/
alpha_prod_t
**
(
0.5
)
def
get_x_prev_and_pred_x0
(
e_t
,
index
):
# select parameters corresponding to the currently considered timestep
# 4. Clip "predicted x_0"
a_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas
[
index
],
device
=
device
)
if
self
.
clip_image
:
a_prev
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas_prev
[
index
],
device
=
device
)
pred_original_image
=
self
.
clip
(
pred_original_image
,
-
1
,
1
)
sigma_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
sigmas
[
index
],
device
=
device
)
sqrt_one_minus_at
=
torch
.
full
((
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
# current prediction for x_0
variance
=
self
.
get_variance
(
t
,
num_inference_steps
)
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
std_dev_t
=
eta
*
variance
**
(
0.5
)
if
quantize_denoised
:
pred_x0
,
_
,
*
_
=
self
.
model
.
first_stage_model
.
quantize
(
pred_x0
)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# direction pointing to x_t
pred_image_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
)
**
(
0.5
)
*
residual
dir_xt
=
(
1.
-
a_prev
-
sigma_t
**
2
).
sqrt
()
*
e_t
noise
=
sigma_t
*
noise_like
(
x
.
shape
,
device
,
repeat_noise
)
*
temperature
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if
noise_dropout
>
0.
:
pred_prev_image
=
alpha_prod_t_prev
**
(
0.5
)
*
pred_original_image
+
pred_image_direction
noise
=
torch
.
nn
.
functional
.
dropout
(
noise
,
p
=
noise_dropout
)
x_prev
=
a_prev
.
sqrt
()
*
pred_x0
+
dir_xt
+
noise
return
pred_prev_image
return
x_prev
,
pred_x0
def
__len__
(
self
):
e_t
=
get_model_output
(
x
,
t
)
return
self
.
timesteps
if
len
(
old_eps
)
==
0
:
# Pseudo Improved Euler (2nd order)
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t
,
index
)
e_t_next
=
get_model_output
(
x_prev
,
t_next
)
e_t_prime
=
(
e_t
+
e_t_next
)
/
2
elif
len
(
old_eps
)
==
1
:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
3
*
e_t
-
old_eps
[
-
1
])
/
2
elif
len
(
old_eps
)
==
2
:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
23
*
e_t
-
16
*
old_eps
[
-
1
]
+
5
*
old_eps
[
-
2
])
/
12
elif
len
(
old_eps
)
>=
3
:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
55
*
e_t
-
59
*
old_eps
[
-
1
]
+
37
*
old_eps
[
-
2
]
-
9
*
old_eps
[
-
3
])
/
24
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t_prime
,
index
)
return
x_prev
,
pred_x0
,
e_t
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment