Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
13c5a065
"...text-generation-inference.git" did not exist on "d0841cc8eb2c27bf96c96aa78c42014ac7a9928b"
Commit
13c5a065
authored
Jun 13, 2022
by
Patrick von Platen
Browse files
add pndm
parent
73074186
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
258 additions
and
343 deletions
+258
-343
run_pndm.py
run_pndm.py
+27
-0
src/diffusers/__init__.py
src/diffusers/__init__.py
+2
-2
src/diffusers/pipelines/__init__.py
src/diffusers/pipelines/__init__.py
+1
-0
src/diffusers/pipelines/pipeline_pndm.py
src/diffusers/pipelines/pipeline_pndm.py
+89
-0
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-0
src/diffusers/schedulers/scheduling_plms.py
src/diffusers/schedulers/scheduling_plms.py
+0
-341
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+138
-0
No files found.
run_pndm.py
0 → 100755
View file @
13c5a065
#!/usr/bin/env python3
from
diffusers
import
PNDM
,
UNetModel
,
PNDMScheduler
import
PIL.Image
import
numpy
as
np
import
torch
model_id
=
"fusing/ddim-celeba-hq"
model
=
UNetModel
.
from_pretrained
(
model_id
)
scheduler
=
PNDMScheduler
()
# load model and scheduler
ddpm
=
PNDM
(
unet
=
model
,
noise_scheduler
=
scheduler
)
# run pipeline in inference (sample random noise and denoise)
image
=
ddpm
()
# process image to PIL
image_processed
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
)
image_processed
=
(
image_processed
+
1.0
)
/
2
image_processed
=
torch
.
clamp
(
image_processed
,
0.0
,
1.0
)
image_processed
=
image_processed
*
255
image_processed
=
image_processed
.
numpy
().
astype
(
np
.
uint8
)
image_pil
=
PIL
.
Image
.
fromarray
(
image_processed
[
0
])
# save image
image_pil
.
save
(
"/home/patrick/images/test.png"
)
src/diffusers/__init__.py
View file @
13c5a065
...
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
...
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.models.unet_glide
import
GLIDESuperResUNetModel
,
GLIDETextToImageUNetModel
from
.models.unet_ldm
import
UNetLDMModel
from
.models.unet_ldm
import
UNetLDMModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.pipelines
import
DDIM
,
DDPM
,
GLIDE
,
LatentDiffusion
from
.pipelines
import
DDIM
,
DDPM
,
GLIDE
,
LatentDiffusion
,
PNDM
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
SchedulerMixin
from
.schedulers
import
DDIMScheduler
,
DDPMScheduler
,
SchedulerMixin
,
PNDMScheduler
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
src/diffusers/pipelines/__init__.py
View file @
13c5a065
from
.pipeline_ddim
import
DDIM
from
.pipeline_ddim
import
DDIM
from
.pipeline_ddpm
import
DDPM
from
.pipeline_ddpm
import
DDPM
from
.pipeline_pndm
import
PNDM
from
.pipeline_glide
import
GLIDE
from
.pipeline_glide
import
GLIDE
from
.pipeline_latent_diffusion
import
LatentDiffusion
from
.pipeline_latent_diffusion
import
LatentDiffusion
src/diffusers/pipelines/pipeline_pndm.py
0 → 100644
View file @
13c5a065
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
tqdm
from
..pipeline_utils
import
DiffusionPipeline
class
PNDM
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
super
().
__init__
()
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
num_trained_timesteps
=
self
.
noise_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
# Sample gaussian noise to begin loop
image
=
torch
.
randn
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
# generator=torch.manual_seed(0)
)
image
=
image
.
to
(
torch_device
)
seq
=
inference_step_times
seq_next
=
[
-
1
]
+
list
(
seq
[:
-
1
])
model
=
self
.
unet
ets
=
[]
for
i
,
j
in
zip
(
reversed
(
seq
),
reversed
(
seq_next
)):
t
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
i
)
t_next
=
(
torch
.
ones
(
image
.
shape
[
0
])
*
j
)
with
torch
.
no_grad
():
t_start
,
t_end
=
t_next
,
t
img_next
,
ets
=
self
.
noise_scheduler
.
step
(
image
,
t_start
,
t_end
,
model
,
ets
)
image
=
img_next
return
image
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# 1. predict noise residual
# with torch.no_grad():
# residual = self.unet(image, inference_step_times[t])
#
# 2. predict previous mean of image x_t-1
# pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
#
# 3. optionally sample variance
# variance = 0
# if eta > 0:
# noise = torch.randn(image.shape, generator=generator).to(image.device)
# variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
#
# 4. set current image to prev_image: x_t -> x_t-1
# image = pred_prev_image + variance
src/diffusers/schedulers/__init__.py
View file @
13c5a065
...
@@ -19,4 +19,5 @@
...
@@ -19,4 +19,5 @@
from
.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.scheduling_ddim
import
DDIMScheduler
from
.scheduling_ddim
import
DDIMScheduler
from
.scheduling_ddpm
import
DDPMScheduler
from
.scheduling_ddpm
import
DDPMScheduler
from
.scheduling_pndm
import
PNDMScheduler
from
.scheduling_utils
import
SchedulerMixin
from
.scheduling_utils
import
SchedulerMixin
src/diffusers/schedulers/scheduling_plms.py
deleted
100644 → 0
View file @
73074186
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
import
torch
from
tqdm
import
tqdm
from
..configuration_utils
import
ConfigMixin
from
.schedulers_utils
import
SchedulerMixin
,
betas_for_alpha_bar
,
linear_beta_schedule
def
noise_like
(
shape
,
device
,
repeat
=
False
):
repeat_noise
=
lambda
:
torch
.
randn
((
1
,
*
shape
[
1
:]),
device
=
device
).
repeat
(
shape
[
0
],
*
((
1
,)
*
(
len
(
shape
)
-
1
)))
noise
=
lambda
:
torch
.
randn
(
shape
,
device
=
device
)
return
repeat_noise
()
if
repeat
else
noise
()
def
make_ddim_timesteps
(
ddim_discr_method
,
num_ddim_timesteps
,
num_ddpm_timesteps
,
verbose
=
True
):
if
ddim_discr_method
==
"uniform"
:
c
=
num_ddpm_timesteps
//
num_ddim_timesteps
ddim_timesteps
=
np
.
asarray
(
list
(
range
(
0
,
num_ddpm_timesteps
,
c
)))
elif
ddim_discr_method
==
"quad"
:
ddim_timesteps
=
((
np
.
linspace
(
0
,
np
.
sqrt
(
num_ddpm_timesteps
*
0.8
),
num_ddim_timesteps
))
**
2
).
astype
(
int
)
else
:
raise
NotImplementedError
(
f
'There is no ddim discretization method called "
{
ddim_discr_method
}
"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out
=
ddim_timesteps
+
1
if
verbose
:
print
(
f
"Selected timesteps for ddim sampler:
{
steps_out
}
"
)
return
steps_out
def
make_ddim_sampling_parameters
(
alphacums
,
ddim_timesteps
,
eta
,
verbose
=
True
):
# select alphas for computing the variance schedule
alphas
=
alphacums
[
ddim_timesteps
]
alphas_prev
=
np
.
asarray
([
alphacums
[
0
]]
+
alphacums
[
ddim_timesteps
[:
-
1
]].
tolist
())
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas
=
eta
*
np
.
sqrt
((
1
-
alphas_prev
)
/
(
1
-
alphas
)
*
(
1
-
alphas
/
alphas_prev
))
if
verbose
:
print
(
f
"Selected alphas for ddim sampler: a_t:
{
alphas
}
; a_(t-1):
{
alphas_prev
}
"
)
print
(
f
"For the chosen value of eta, which is
{
eta
}
, "
f
"this results in the following sigma_t schedule for ddim sampler
{
sigmas
}
"
)
return
sigmas
,
alphas
,
alphas_prev
class
PLMSSampler
(
object
):
def
__init__
(
self
,
model
,
schedule
=
"linear"
,
**
kwargs
):
super
().
__init__
()
self
.
model
=
model
self
.
ddpm_num_timesteps
=
model
.
num_timesteps
self
.
schedule
=
schedule
def
register_buffer
(
self
,
name
,
attr
):
if
type
(
attr
)
==
torch
.
Tensor
:
if
attr
.
device
!=
torch
.
device
(
"cuda"
):
attr
=
attr
.
to
(
torch
.
device
(
"cuda"
))
setattr
(
self
,
name
,
attr
)
def
make_schedule
(
self
,
ddim_num_steps
,
ddim_discretize
=
"uniform"
,
ddim_eta
=
0.0
,
verbose
=
True
):
if
ddim_eta
!=
0
:
raise
ValueError
(
"ddim_eta must be 0 for PLMS"
)
self
.
ddim_timesteps
=
make_ddim_timesteps
(
ddim_discr_method
=
ddim_discretize
,
num_ddim_timesteps
=
ddim_num_steps
,
num_ddpm_timesteps
=
self
.
ddpm_num_timesteps
,
verbose
=
verbose
,
)
alphas_cumprod
=
self
.
model
.
alphas_cumprod
assert
alphas_cumprod
.
shape
[
0
]
==
self
.
ddpm_num_timesteps
,
"alphas have to be defined for each timestep"
to_torch
=
lambda
x
:
x
.
clone
().
detach
().
to
(
torch
.
float32
).
to
(
self
.
model
.
device
)
self
.
register_buffer
(
"betas"
,
to_torch
(
self
.
model
.
betas
))
self
.
register_buffer
(
"alphas_cumprod"
,
to_torch
(
alphas_cumprod
))
self
.
register_buffer
(
"alphas_cumprod_prev"
,
to_torch
(
self
.
model
.
alphas_cumprod_prev
))
# calculations for diffusion q(x_t | x_{t-1}) and others
self
.
register_buffer
(
"sqrt_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
"sqrt_one_minus_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
1.0
-
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
"log_one_minus_alphas_cumprod"
,
to_torch
(
np
.
log
(
1.0
-
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
"sqrt_recip_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
1.0
/
alphas_cumprod
.
cpu
())))
self
.
register_buffer
(
"sqrt_recipm1_alphas_cumprod"
,
to_torch
(
np
.
sqrt
(
1.0
/
alphas_cumprod
.
cpu
()
-
1
)))
# ddim sampling parameters
ddim_sigmas
,
ddim_alphas
,
ddim_alphas_prev
=
make_ddim_sampling_parameters
(
alphacums
=
alphas_cumprod
.
cpu
(),
ddim_timesteps
=
self
.
ddim_timesteps
,
eta
=
ddim_eta
,
verbose
=
verbose
)
self
.
register_buffer
(
"ddim_sigmas"
,
ddim_sigmas
)
self
.
register_buffer
(
"ddim_alphas"
,
ddim_alphas
)
self
.
register_buffer
(
"ddim_alphas_prev"
,
ddim_alphas_prev
)
self
.
register_buffer
(
"ddim_sqrt_one_minus_alphas"
,
np
.
sqrt
(
1.0
-
ddim_alphas
))
sigmas_for_original_sampling_steps
=
ddim_eta
*
torch
.
sqrt
(
(
1
-
self
.
alphas_cumprod_prev
)
/
(
1
-
self
.
alphas_cumprod
)
*
(
1
-
self
.
alphas_cumprod
/
self
.
alphas_cumprod_prev
)
)
self
.
register_buffer
(
"ddim_sigmas_for_original_num_steps"
,
sigmas_for_original_sampling_steps
)
@
torch
.
no_grad
()
def
sample
(
self
,
S
,
batch_size
,
shape
,
conditioning
=
None
,
callback
=
None
,
normals_sequence
=
None
,
img_callback
=
None
,
quantize_x0
=
False
,
eta
=
0.0
,
mask
=
None
,
x0
=
None
,
temperature
=
1.0
,
noise_dropout
=
0.0
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
verbose
=
True
,
x_T
=
None
,
log_every_t
=
100
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**
kwargs
,
):
if
conditioning
is
not
None
:
if
isinstance
(
conditioning
,
dict
):
cbs
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]].
shape
[
0
]
if
cbs
!=
batch_size
:
print
(
f
"Warning: Got
{
cbs
}
conditionings but batch-size is
{
batch_size
}
"
)
else
:
if
conditioning
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got
{
conditioning
.
shape
[
0
]
}
conditionings but batch-size is
{
batch_size
}
"
)
self
.
make_schedule
(
ddim_num_steps
=
S
,
ddim_eta
=
eta
,
verbose
=
verbose
)
# sampling
C
,
H
,
W
=
shape
size
=
(
batch_size
,
C
,
H
,
W
)
print
(
f
"Data shape for PLMS sampling is
{
size
}
"
)
samples
,
intermediates
=
self
.
plms_sampling
(
conditioning
,
size
,
callback
=
callback
,
img_callback
=
img_callback
,
quantize_denoised
=
quantize_x0
,
mask
=
mask
,
x0
=
x0
,
ddim_use_original_steps
=
False
,
noise_dropout
=
noise_dropout
,
temperature
=
temperature
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
x_T
=
x_T
,
log_every_t
=
log_every_t
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
)
return
samples
,
intermediates
@
torch
.
no_grad
()
def
plms_sampling
(
self
,
cond
,
shape
,
x_T
=
None
,
ddim_use_original_steps
=
False
,
callback
=
None
,
timesteps
=
None
,
quantize_denoised
=
False
,
mask
=
None
,
x0
=
None
,
img_callback
=
None
,
log_every_t
=
100
,
temperature
=
1.0
,
noise_dropout
=
0.0
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
):
device
=
self
.
model
.
betas
.
device
b
=
shape
[
0
]
if
x_T
is
None
:
img
=
torch
.
randn
(
shape
,
device
=
device
)
else
:
img
=
x_T
if
timesteps
is
None
:
timesteps
=
self
.
ddpm_num_timesteps
if
ddim_use_original_steps
else
self
.
ddim_timesteps
elif
timesteps
is
not
None
and
not
ddim_use_original_steps
:
subset_end
=
int
(
min
(
timesteps
/
self
.
ddim_timesteps
.
shape
[
0
],
1
)
*
self
.
ddim_timesteps
.
shape
[
0
])
-
1
timesteps
=
self
.
ddim_timesteps
[:
subset_end
]
intermediates
=
{
"x_inter"
:
[
img
],
"pred_x0"
:
[
img
]}
time_range
=
list
(
reversed
(
range
(
0
,
timesteps
)))
if
ddim_use_original_steps
else
np
.
flip
(
timesteps
)
total_steps
=
timesteps
if
ddim_use_original_steps
else
timesteps
.
shape
[
0
]
print
(
f
"Running PLMS Sampling with
{
total_steps
}
timesteps"
)
iterator
=
tqdm
(
time_range
,
desc
=
"PLMS Sampler"
,
total
=
total_steps
)
old_eps
=
[]
for
i
,
step
in
enumerate
(
iterator
):
index
=
total_steps
-
i
-
1
ts
=
torch
.
full
((
b
,),
step
,
device
=
device
,
dtype
=
torch
.
long
)
ts_next
=
torch
.
full
((
b
,),
time_range
[
min
(
i
+
1
,
len
(
time_range
)
-
1
)],
device
=
device
,
dtype
=
torch
.
long
)
if
mask
is
not
None
:
assert
x0
is
not
None
img_orig
=
self
.
model
.
q_sample
(
x0
,
ts
)
# TODO: deterministic forward pass?
img
=
img_orig
*
mask
+
(
1.0
-
mask
)
*
img
outs
=
self
.
p_sample_plms
(
img
,
cond
,
ts
,
index
=
index
,
use_original_steps
=
ddim_use_original_steps
,
quantize_denoised
=
quantize_denoised
,
temperature
=
temperature
,
noise_dropout
=
noise_dropout
,
score_corrector
=
score_corrector
,
corrector_kwargs
=
corrector_kwargs
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
old_eps
=
old_eps
,
t_next
=
ts_next
,
)
img
,
pred_x0
,
e_t
=
outs
old_eps
.
append
(
e_t
)
if
len
(
old_eps
)
>=
4
:
old_eps
.
pop
(
0
)
if
callback
:
callback
(
i
)
if
img_callback
:
img_callback
(
pred_x0
,
i
)
if
index
%
log_every_t
==
0
or
index
==
total_steps
-
1
:
intermediates
[
"x_inter"
].
append
(
img
)
intermediates
[
"pred_x0"
].
append
(
pred_x0
)
return
img
,
intermediates
@
torch
.
no_grad
()
def
p_sample_plms
(
self
,
x
,
c
,
t
,
index
,
repeat_noise
=
False
,
use_original_steps
=
False
,
quantize_denoised
=
False
,
temperature
=
1.0
,
noise_dropout
=
0.0
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.0
,
unconditional_conditioning
=
None
,
old_eps
=
None
,
t_next
=
None
,
):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
def
get_model_output
(
x
,
t
):
if
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.0
:
e_t
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
else
:
x_in
=
torch
.
cat
([
x
]
*
2
)
t_in
=
torch
.
cat
([
t
]
*
2
)
c_in
=
torch
.
cat
([
unconditional_conditioning
,
c
])
e_t_uncond
,
e_t
=
self
.
model
.
apply_model
(
x_in
,
t_in
,
c_in
).
chunk
(
2
)
e_t
=
e_t_uncond
+
unconditional_guidance_scale
*
(
e_t
-
e_t_uncond
)
if
score_corrector
is
not
None
:
assert
self
.
model
.
parameterization
==
"eps"
e_t
=
score_corrector
.
modify_score
(
self
.
model
,
e_t
,
x
,
t
,
c
,
**
corrector_kwargs
)
return
e_t
alphas
=
self
.
model
.
alphas_cumprod
if
use_original_steps
else
self
.
ddim_alphas
alphas_prev
=
self
.
model
.
alphas_cumprod_prev
if
use_original_steps
else
self
.
ddim_alphas_prev
sqrt_one_minus_alphas
=
(
self
.
model
.
sqrt_one_minus_alphas_cumprod
if
use_original_steps
else
self
.
ddim_sqrt_one_minus_alphas
)
sigmas
=
self
.
model
.
ddim_sigmas_for_original_num_steps
if
use_original_steps
else
self
.
ddim_sigmas
def
get_x_prev_and_pred_x0
(
e_t
,
index
):
# select parameters corresponding to the currently considered timestep
a_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas
[
index
],
device
=
device
)
a_prev
=
torch
.
full
((
b
,
1
,
1
,
1
),
alphas_prev
[
index
],
device
=
device
)
sigma_t
=
torch
.
full
((
b
,
1
,
1
,
1
),
sigmas
[
index
],
device
=
device
)
sqrt_one_minus_at
=
torch
.
full
((
b
,
1
,
1
,
1
),
sqrt_one_minus_alphas
[
index
],
device
=
device
)
# current prediction for x_0
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
if
quantize_denoised
:
pred_x0
,
_
,
*
_
=
self
.
model
.
first_stage_model
.
quantize
(
pred_x0
)
# direction pointing to x_t
dir_xt
=
(
1.0
-
a_prev
-
sigma_t
**
2
).
sqrt
()
*
e_t
noise
=
sigma_t
*
noise_like
(
x
.
shape
,
device
,
repeat_noise
)
*
temperature
if
noise_dropout
>
0.0
:
noise
=
torch
.
nn
.
functional
.
dropout
(
noise
,
p
=
noise_dropout
)
x_prev
=
a_prev
.
sqrt
()
*
pred_x0
+
dir_xt
+
noise
return
x_prev
,
pred_x0
e_t
=
get_model_output
(
x
,
t
)
if
len
(
old_eps
)
==
0
:
# Pseudo Improved Euler (2nd order)
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t
,
index
)
e_t_next
=
get_model_output
(
x_prev
,
t_next
)
e_t_prime
=
(
e_t
+
e_t_next
)
/
2
elif
len
(
old_eps
)
==
1
:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
3
*
e_t
-
old_eps
[
-
1
])
/
2
elif
len
(
old_eps
)
==
2
:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
23
*
e_t
-
16
*
old_eps
[
-
1
]
+
5
*
old_eps
[
-
2
])
/
12
elif
len
(
old_eps
)
>=
3
:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime
=
(
55
*
e_t
-
59
*
old_eps
[
-
1
]
+
37
*
old_eps
[
-
2
]
-
9
*
old_eps
[
-
3
])
/
24
x_prev
,
pred_x0
=
get_x_prev_and_pred_x0
(
e_t_prime
,
index
)
return
x_prev
,
pred_x0
,
e_t
src/diffusers/schedulers/scheduling_pndm.py
0 → 100644
View file @
13c5a065
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
,
betas_for_alpha_bar
,
linear_beta_schedule
class
PNDMScheduler
(
SchedulerMixin
,
ConfigMixin
):
def
__init__
(
self
,
timesteps
=
1000
,
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
tensor_format
=
"np"
,
):
super
().
__init__
()
self
.
register
(
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
)
self
.
timesteps
=
int
(
timesteps
)
if
beta_schedule
==
"linear"
:
self
.
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
elif
beta_schedule
==
"squaredcos_cap_v2"
:
# GLIDE cosine schedule
self
.
betas
=
betas_for_alpha_bar
(
timesteps
,
lambda
t
:
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
,
)
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
self
.
alphas
=
1.0
-
self
.
betas
self
.
alphas_cumprod
=
np
.
cumprod
(
self
.
alphas
,
axis
=
0
)
self
.
one
=
np
.
array
(
1.0
)
self
.
set_format
(
tensor_format
=
tensor_format
)
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# TODO(PVP) - check how much of these is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
# log_variance = torch.log(variance.clamp(min=1e-20))
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
def
get_alpha
(
self
,
time_step
):
return
self
.
alphas
[
time_step
]
def
get_beta
(
self
,
time_step
):
return
self
.
betas
[
time_step
]
def
get_alpha_prod
(
self
,
time_step
):
if
time_step
<
0
:
return
self
.
one
return
self
.
alphas_cumprod
[
time_step
]
def
step
(
self
,
img
,
t_start
,
t_end
,
model
,
ets
):
# img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
t_next
,
t
=
t_start
,
t_end
t_list
=
[
t
,
(
t
+
t_next
)
/
2
,
t_next
]
alphas_cump
=
self
.
alphas_cumprod
if
len
(
ets
)
>
2
:
noise_
=
model
(
img
.
to
(
"cuda"
),
t
.
to
(
"cuda"
))
noise_
=
noise_
.
to
(
"cpu"
)
ets
.
append
(
noise_
)
noise
=
(
1
/
24
)
*
(
55
*
ets
[
-
1
]
-
59
*
ets
[
-
2
]
+
37
*
ets
[
-
3
]
-
9
*
ets
[
-
4
])
else
:
noise
=
self
.
runge_kutta
(
img
,
t_list
,
model
,
alphas_cump
,
ets
)
img_next
=
self
.
transfer
(
img
.
to
(
"cpu"
),
t
,
t_next
,
noise
,
alphas_cump
)
return
img_next
,
ets
def
runge_kutta
(
self
,
x
,
t_list
,
model
,
alphas_cump
,
ets
):
model
=
model
.
to
(
"cuda"
)
x
=
x
.
to
(
"cpu"
)
e_1
=
model
(
x
.
to
(
"cuda"
),
t_list
[
0
].
to
(
"cuda"
))
e_1
=
e_1
.
to
(
"cpu"
)
ets
.
append
(
e_1
)
x_2
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_1
,
alphas_cump
)
e_2
=
model
(
x_2
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
e_2
=
e_2
.
to
(
"cpu"
)
x_3
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
1
],
e_2
,
alphas_cump
)
e_3
=
model
(
x_3
.
to
(
"cuda"
),
t_list
[
1
].
to
(
"cuda"
))
e_3
=
e_3
.
to
(
"cpu"
)
x_4
=
self
.
transfer
(
x
,
t_list
[
0
],
t_list
[
2
],
e_3
,
alphas_cump
)
e_4
=
model
(
x_4
.
to
(
"cuda"
),
t_list
[
2
].
to
(
"cuda"
))
e_4
=
e_4
.
to
(
"cpu"
)
et
=
(
1
/
6
)
*
(
e_1
+
2
*
e_2
+
2
*
e_3
+
e_4
)
return
et
def
transfer
(
self
,
x
,
t
,
t_next
,
et
,
alphas_cump
):
at
=
alphas_cump
[
t
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
at_next
=
alphas_cump
[
t_next
.
long
()
+
1
].
view
(
-
1
,
1
,
1
,
1
)
x_delta
=
(
at_next
-
at
)
*
((
1
/
(
at
.
sqrt
()
*
(
at
.
sqrt
()
+
at_next
.
sqrt
())))
*
x
-
1
/
(
at
.
sqrt
()
*
(((
1
-
at_next
)
*
at
).
sqrt
()
+
((
1
-
at
)
*
at_next
).
sqrt
()))
*
et
)
x_next
=
x
+
x_delta
return
x_next
def
__len__
(
self
):
return
self
.
timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment