Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
f28cb9e1
Commit
f28cb9e1
authored
Jun 13, 2022
by
Patrick von Platen
Browse files
add dummy code for pmls
parent
1f66160e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
263 additions
and
121 deletions
+263
-121
src/diffusers/schedulers/scheduling_plms.py
src/diffusers/schedulers/scheduling_plms.py
+263
-121
No files found.
src/diffusers/schedulers/scheduling_plms.py
View file @
f28cb9e1
...
@@ -14,131 +14,273 @@
...
@@ -14,131 +14,273 @@
import
math
import
math
import
numpy
as
np
import
numpy
as
np
import
torch
from
tqdm
import
tqdm
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
.schedulers_utils
import
SchedulerMixin
,
betas_for_alpha_bar
,
linear_beta_schedule
from
.schedulers_utils
import
SchedulerMixin
,
betas_for_alpha_bar
,
linear_beta_schedule
class
DDIMScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
noise_like
(
shape
,
device
,
repeat
=
False
):
def
__init__
(
repeat_noise
=
lambda
:
torch
.
randn
((
1
,
*
shape
[
1
:]),
device
=
device
).
repeat
(
shape
[
0
],
*
((
1
,)
*
(
len
(
shape
)
-
1
)))
self
,
noise
=
lambda
:
torch
.
randn
(
shape
,
device
=
device
)
timesteps
=
1000
,
return
repeat_noise
()
if
repeat
else
noise
()
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
def
make_ddim_timesteps
(
ddim_discr_method
,
num_ddim_timesteps
,
num_ddpm_timesteps
,
verbose
=
True
):
clip_predicted_image
=
True
,
if
ddim_discr_method
==
'uniform'
:
tensor_format
=
"np"
,
c
=
num_ddpm_timesteps
//
num_ddim_timesteps
):
ddim_timesteps
=
np
.
asarray
(
list
(
range
(
0
,
num_ddpm_timesteps
,
c
)))
elif
ddim_discr_method
==
'quad'
:
ddim_timesteps
=
((
np
.
linspace
(
0
,
np
.
sqrt
(
num_ddpm_timesteps
*
.
8
),
num_ddim_timesteps
))
**
2
).
astype
(
int
)
else
:
raise
NotImplementedError
(
f
'There is no ddim discretization method called "
{
ddim_discr_method
}
"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out
=
ddim_timesteps
+
1
if
verbose
:
print
(
f
'Selected timesteps for ddim sampler:
{
steps_out
}
'
)
return
steps_out
def
make_ddim_sampling_parameters
(
alphacums
,
ddim_timesteps
,
eta
,
verbose
=
True
):
# select alphas for computing the variance schedule
alphas
=
alphacums
[
ddim_timesteps
]
alphas_prev
=
np
.
asarray
([
alphacums
[
0
]]
+
alphacums
[
ddim_timesteps
[:
-
1
]].
tolist
())
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas
=
eta
*
np
.
sqrt
((
1
-
alphas_prev
)
/
(
1
-
alphas
)
*
(
1
-
alphas
/
alphas_prev
))
if
verbose
:
print
(
f
'Selected alphas for ddim sampler: a_t:
{
alphas
}
; a_(t-1):
{
alphas_prev
}
'
)
print
(
f
'For the chosen value of eta, which is
{
eta
}
, '
f
'this results in the following sigma_t schedule for ddim sampler
{
sigmas
}
'
)
return
sigmas
,
alphas
,
alphas_prev
class
PLMSSampler
(
object
):
def
__init__
(
self
,
model
,
schedule
=
"linear"
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
model
=
model
timesteps
=
timesteps
,
self
.
ddpm_num_timesteps
=
model
.
num_timesteps
beta_start
=
beta_start
,
self
.
schedule
=
schedule
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
def
register_buffer
(
self
,
name
,
attr
):
)
if
type
(
attr
)
==
torch
.
Tensor
:
self
.
timesteps
=
int
(
timesteps
)
if
attr
.
device
!=
torch
.
device
(
"cuda"
):
self
.
clip_image
=
clip_predicted_image
attr
=
attr
.
to
(
torch
.
device
(
"cuda"
))
setattr
(
self
,
name
,
attr
)
if
beta_schedule
==
"linear"
:
self
.
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
make_schedule
(
self
,
ddim_num_steps
,
ddim_discretize
=
"uniform"
,
ddim_eta
=
0.
,
verbose
=
True
):
elif
beta_schedule
==
"squaredcos_cap_v2"
:
if
ddim_eta
!=
0
:
# GLIDE cosine schedule
raise
ValueError
(
'ddim_eta must be 0 for PLMS'
)
self
.
betas
=
betas_for_alpha_bar
(
self
.
ddim_timesteps
=
make_ddim_timesteps
(
ddim_discr_method
=
ddim_discretize
,
num_ddim_timesteps
=
ddim_num_steps
,
timesteps
,
num_ddpm_timesteps
=
self
.
ddpm_num_timesteps
,
verbose
=
verbose
)
lambda
t
:
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
,
alphas_cumprod
=
self
.
model
.
alphas_cumprod
assert
alphas_cumprod
.
shape
[
0
]
==
self
.
ddpm_num_timesteps
,
'alphas have to be defined for each timestep'
to_torch
=
lambda
x
:
x
.
clone
().
detach
().
to
(
torch
.
float32
).
to
(
self
.
model
.
device
)
self
.
register_buffer
(
'betas'
,
to_torch
(
self
.
model
.
betas
))
self
.
register_buffer
(
'alphas_cumprod'
,
to_torch
(
alphas_cumprod
))
self
.
register_buffer
(
'alphas_cumprod_prev'
,
to_torch
(
self
.
model
.
alphas_cumprod_prev
))
# calculations for diffusion q(x_t | x_{t-1}) and others
self
.
register_buffer
(
'sqrt_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'sqrt_one_minus_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
-
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'log_one_minus_alphas_cumprod'
,
to_torch
(
np
.
log
(
1.
-
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'sqrt_recip_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
'sqrt_recipm1_alphas_cumprod'
,
to_torch
(
np
.
sqrt
(
1.
/
alphas_cumprod
.
cpu
()
-
1
)))
# ddim sampling parameters
ddim_sigmas
,
ddim_alphas
,
ddim_alphas_prev
=
make_ddim_sampling_parameters
(
alphacums
=
alphas_cumprod
.
cpu
(),
ddim_timesteps
=
self
.
ddim_timesteps
,
eta
=
ddim_eta
,
verbose
=
verbose
)
self
.
register_buffer
(
'ddim_sigmas'
,
ddim_sigmas
)
self
.
register_buffer
(
'ddim_alphas'
,
ddim_alphas
)
self
.
register_buffer
(
'ddim_alphas_prev'
,
ddim_alphas_prev
)
self
.
register_buffer
(
'ddim_sqrt_one_minus_alphas'
,
np
.
sqrt
(
1.
-
ddim_alphas
))
sigmas_for_original_sampling_steps
=
ddim_eta
*
torch
.
sqrt
(
(
1
-
self
.
alphas_cumprod_prev
)
/
(
1
-
self
.
alphas_cumprod
)
*
(
1
-
self
.
alphas_cumprod
/
self
.
alphas_cumprod_prev
))
self
.
register_buffer
(
'ddim_sigmas_for_original_num_steps'
,
sigmas_for_original_sampling_steps
)
@
torch
.
no_grad
()
def
sample
(
self
,
S
,
batch_size
,
shape
,
conditioning
=
None
,
callback
=
None
,
normals_sequence
=
None
,
img_callback
=
None
,
quantize_x0
=
False
,
eta
=
0.
,
mask
=
None
,
x0
=
None
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
verbose
=
True
,
x_T
=
None
,
log_every_t
=
100
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**
kwargs
):
if
conditioning
is
not
None
:
if
isinstance
(
conditioning
,
dict
):
cbs
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]].
shape
[
0
]
if
cbs
!=
batch_size
:
print
(
f
"Warning: Got
{
cbs
}
conditionings but batch-size is
{
batch_size
}
"
)
else
:
if
conditioning
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got
{
conditioning
.
shape
[
0
]
}
conditionings but batch-size is
{
batch_size
}
"
)
self
.
make_schedule
(
ddim_num_steps
=
S
,
ddim_eta
=
eta
,
verbose
=
verbose
)
# sampling
C
,
H
,
W
=
shape
size
=
(
batch_size
,
C
,
H
,
W
)
print
(
f
'Data shape for PLMS sampling is
{
size
}
'
)
samples
,
intermediates
=
self
.
plms_sampling
(
conditioning
,
size
,
callback
=
callback
,
img_callback
=
img_callback
,
quantize_denoised
=
quantize_x0
,
mask
=
mask
,
x0
=
x0
,
ddim_use_original_steps
=
False
,
noise_dropout
=
noise_dropout
,
temperature
=
temperature
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
x_T
=
x_T
,
log_every_t
=
log_every_t
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
)
)
return
samples
,
intermediates
@
torch
.
no_grad
()
def
plms_sampling
(
self
,
cond
,
shape
,
x_T
=
None
,
ddim_use_original_steps
=
False
,
callback
=
None
,
timesteps
=
None
,
quantize_denoised
=
False
,
mask
=
None
,
x0
=
None
,
img_callback
=
None
,
log_every_t
=
100
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,):
device
=
self
.
model
.
betas
.
device
b
=
shape
[
0
]
if
x_T
is
None
:
img
=
torch
.
randn
(
shape
,
device
=
device
)
else
:
img
=
x_T
if
timesteps
is
None
:
timesteps
=
self
.
ddpm_num_timesteps
if
ddim_use_original_steps
else
self
.
ddim_timesteps
elif
timesteps
is
not
None
and
not
ddim_use_original_steps
:
subset_end
=
int
(
min
(
timesteps
/
self
.
ddim_timesteps
.
shape
[
0
],
1
)
*
self
.
ddim_timesteps
.
shape
[
0
])
-
1
timesteps
=
self
.
ddim_timesteps
[:
subset_end
]
intermediates
=
{
'x_inter'
:
[
img
],
'pred_x0'
:
[
img
]}
time_range
=
list
(
reversed
(
range
(
0
,
timesteps
)))
if
ddim_use_original_steps
else
np
.
flip
(
timesteps
)
total_steps
=
timesteps
if
ddim_use_original_steps
else
timesteps
.
shape
[
0
]
print
(
f
"Running PLMS Sampling with
{
total_steps
}
timesteps"
)
iterator
=
tqdm
(
time_range
,
desc
=
'PLMS Sampler'
,
total
=
total_steps
)
old_eps
=
[]
for
i
,
step
in
enumerate
(
iterator
):
index
=
total_steps
-
i
-
1
ts
=
torch
.
full
((
b
,),
step
,
device
=
device
,
dtype
=
torch
.
long
)
ts_next
=
torch
.
full
((
b
,),
time_range
[
min
(
i
+
1
,
len
(
time_range
)
-
1
)],
device
=
device
,
dtype
=
torch
.
long
)
if
mask
is
not
None
:
assert
x0
is
not
None
img_orig
=
self
.
model
.
q_sample
(
x0
,
ts
)
# TODO: deterministic forward pass?
img
=
img_orig
*
mask
+
(
1.
-
mask
)
*
img
outs
=
self
.
p_sample_plms
(
img
,
cond
,
ts
,
index
=
index
,
use_original_steps
=
ddim_use_original_steps
,
quantize_denoised
=
quantize_denoised
,
temperature
=
temperature
,
noise_dropout
=
noise_dropout
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
old_eps
=
old_eps
,
t_next
=
ts_next
)
img
,
pred_x0
,
e_t
=
outs
old_eps
.
append
(
e_t
)
if
len
(
old_eps
)
>=
4
:
old_eps
.
pop
(
0
)
if
callback
:
callback
(
i
)
if
img_callback
:
img_callback
(
pred_x0
,
i
)
if
index
%
log_every_t
==
0
or
index
==
total_steps
-
1
:
intermediates
[
'x_inter'
].
append
(
img
)
intermediates
[
'pred_x0'
].
append
(
pred_x0
)
return
img
,
intermediates
@
torch
.
no_grad
()
def
p_sample_plms
(
self
,
x
,
c
,
t
,
index
,
repeat_noise
=
False
,
use_original_steps
=
False
,
quantize_denoised
=
False
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
old_eps
=
None
,
t_next
=
None
):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
def
get_model_output
(
x
,
t
):
if
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
e_t
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
else
:
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
x_in
=
torch
.
cat
([
x
]
*
2
)
t_in
=
torch
.
cat
([
t
]
*
2
)
self
.
alphas
=
1.0
-
self
.
betas
c_in
=
torch
.
cat
([
unconditional_conditioning
,
c
])
self
.
alphas_cumprod
=
np
.
cumprod
(
self
.
alphas
,
axis
=
0
)
e_t_uncond
,
e_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
).
chunk
(
2
)
self
.
one
=
np
.
array
(
1.0
)
e_t
=
e_t_uncond
+
unconditional_guidance_scale
*
(
e_t
-
e_t_uncond
)
self
.
set_format
(
tensor_format
=
tensor_format
)
if
score_corrector
is
not
None
:
assert
self
.
model
.
parameterization
==
"eps"
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
# TODO(PVP) - check how much of these is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
return
e_t
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
# if variance_type == "fixed_small":
alphas_prev
=
self
.
model
.
alphas_cumprod_prev
if
use_original_steps
else
self
.
ddim_alphas_prev
# log_variance = torch.log(variance.clamp(min=1e-20))
sqrt_one_minus_alphas
=
self
.
model
.
sqrt_one_minus_alphas_cumprod
if
use_original_steps
else
self
.
ddim_sqrt_one_minus_alphas
# elif variance_type == "fixed_large":
sigmas
=
self
.
model
.
ddim_sigmas_for_original_num_steps
if
use_original_steps
else
self
.
ddim_sigmas
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
def
get_x_prev_and_pred_x0
(
e_t
,
index
):
#
# select parameters corresponding to the currently considered timestep
# self.register_buffer("log_variance", log_variance.to(torch.float32))
a_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas
[
index
],
device
=
device
)
a_prev
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas_prev
[
index
],
device
=
device
)
def
get_alpha
(
self
,
time_step
):
sigma_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
sigmas
[
index
],
device
=
device
)
return
self
.
alphas
[
time_step
]
sqrt_one_minus_at
=
torch
.
full
((
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
def
get_beta
(
self
,
time_step
):
# current prediction for x_0
return
self
.
betas
[
time_step
]
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
if
quantize_denoised
:
def
get_alpha_prod
(
self
,
time_step
):
pred_x0
,
_
,
*
_
=
self
.
model
.
first_stage_model
.
quantize
(
pred_x0
)
if
time_step
<
0
:
# direction pointing to x_t
return
self
.
one
dir_xt
=
(
1.
-
a_prev
-
sigma_t
**
2
).
sqrt
()
*
e_t
return
self
.
alphas_cumprod
[
time_step
]
noise
=
sigma_t
*
noise_like
(
x
.
shape
,
device
,
repeat_noise
)
*
temperature
if
noise_dropout
>
0.
:
def
get_orig_t
(
self
,
t
,
num_inference_steps
):
noise
=
torch
.
nn
.
functional
.
dropout
(
noise
,
p
=
noise_dropout
)
if
t
<
0
:
x_prev
=
a_prev
.
sqrt
()
*
pred_x0
+
dir_xt
+
noise
return
-
1
return
x_prev
,
pred_x0
return
self
.
timesteps
//
num_inference_steps
*
t
e_t
=
get_model_output
(
x
,
t
)
def
get_variance
(
self
,
t
,
num_inference_steps
):
if
len
(
old_eps
)
==
0
:
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
# Pseudo Improved Euler (2nd order)
orig_prev_t
=
self
.
get_orig_t
(
t
-
1
,
num_inference_steps
)
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t
,
index
)
e_t_next
=
get_model_output
(
x_prev
,
t_next
)
alpha_prod_t
=
self
.
get_alpha_prod
(
orig_t
)
e_t_prime
=
(
e_t
+
e_t_next
)
/
2
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
orig_prev_t
)
elif
len
(
old_eps
)
==
1
:
beta_prod_t
=
1
-
alpha_prod_t
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
e_t_prime
=
(
3
*
e_t
-
old_eps
[
-
1
])
/
2
elif
len
(
old_eps
)
==
2
:
variance
=
(
beta_prod_t_prev
/
beta_prod_t
)
*
(
1
-
alpha_prod_t
/
alpha_prod_t_prev
)
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
23
*
e_t
-
16
*
old_eps
[
-
1
]
+
5
*
old_eps
[
-
2
])
/
12
return
variance
elif
len
(
old_eps
)
>=
3
:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
,
eta
):
e_t_prime
=
(
55
*
e_t
-
59
*
old_eps
[
-
1
]
+
37
*
old_eps
[
-
2
]
-
9
*
old_eps
[
-
3
])
/
24
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t_prime
,
index
)
# Notation (<variable name> -> <name in paper>
return
x_prev
,
pred_x0
,
e_t
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
# 1. get actual t and t-1
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
orig_prev_t
=
self
.
get_orig_t
(
t
-
1
,
num_inference_steps
)
# 2. compute alphas, betas
alpha_prod_t
=
self
.
get_alpha_prod
(
orig_t
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
orig_prev_t
)
beta_prod_t
=
1
-
alpha_prod_t
# 3. compute predicted original image from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_image
=
(
image
-
beta_prod_t
**
(
0.5
)
*
residual
)
/
alpha_prod_t
**
(
0.5
)
# 4. Clip "predicted x_0"
if
self
.
clip_image
:
pred_original_image
=
self
.
clip
(
pred_original_image
,
-
1
,
1
)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance
=
self
.
get_variance
(
t
,
num_inference_steps
)
std_dev_t
=
eta
*
variance
**
(
0.5
)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
)
**
(
0.5
)
*
residual
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_image
=
alpha_prod_t_prev
**
(
0.5
)
*
pred_original_image
+
pred_image_direction
return
pred_prev_image
def
__len__
(
self
):
return
self
.
timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment