Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
f448360b
Unverified
Commit
f448360b
authored
Jul 15, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 15, 2022
Browse files
Finish scheduler API (#91)
* finish * up
parent
97e1e3ba
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
233 additions
and
188 deletions
+233
-188
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+4
-4
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+2
-2
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+5
-3
src/diffusers/pipelines/ddim/pipeline_ddim.py
src/diffusers/pipelines/ddim/pipeline_ddim.py
+11
-29
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+6
-6
src/diffusers/pipelines/glide/pipeline_glide.py
src/diffusers/pipelines/glide/pipeline_glide.py
+11
-13
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
...s/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+6
-6
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+12
-33
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+7
-7
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+43
-15
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+11
-2
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+20
-4
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+20
-21
tests/test_scheduler.py
tests/test_scheduler.py
+75
-43
No files found.
src/diffusers/models/unet_glide.py
View file @
f448360b
...
...
@@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
def
forward
(
self
,
sample
,
step
_value
,
transformer_out
=
None
):
timesteps
=
step
_value
def
forward
(
self
,
sample
,
time
step
,
transformer_out
=
None
):
timesteps
=
time
step
x
=
sample
hs
=
[]
emb
=
self
.
time_embed
(
...
...
@@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
resblock_updown
=
resblock_updown
,
)
def
forward
(
self
,
sample
,
step
_value
,
low_res
=
None
):
timesteps
=
step
_value
def
forward
(
self
,
sample
,
time
step
,
low_res
=
None
):
timesteps
=
time
step
x
=
sample
_
,
_
,
new_height
,
new_width
=
x
.
shape
upsampled
=
F
.
interpolate
(
low_res
,
(
new_height
,
new_width
),
mode
=
"bilinear"
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
f448360b
...
...
@@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
sample
,
step
_value
,
sigmas
=
None
):
timesteps
=
step
_value
def
forward
(
self
,
sample
,
time
step
,
sigmas
=
None
):
timesteps
=
time
step
x
=
sample
# timestep/noise_level embedding; only for continuous training
modules
=
self
.
all_modules
...
...
src/diffusers/models/unet_unconditional.py
View file @
f448360b
...
...
@@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# ======================================
def
forward
(
self
,
sample
:
torch
.
FloatTensor
,
step
_value
:
Union
[
torch
.
Tensor
,
float
,
int
]
self
,
sample
:
torch
.
FloatTensor
,
time
step
:
Union
[
torch
.
Tensor
,
float
,
int
]
)
->
Dict
[
str
,
torch
.
FloatTensor
]:
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
...
...
@@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
set_weights
()
# ======================================
# 1. time step embeddings
timesteps
=
step
_value
# 1. time step embeddings
-> make correct tensor
timesteps
=
time
step
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
elif
torch
.
is_tensor
(
timesteps
)
and
len
(
timesteps
.
shape
)
==
0
:
timesteps
=
timesteps
[
None
].
to
(
sample
.
device
)
t_emb
=
get_timestep_embedding
(
timesteps
,
...
...
src/diffusers/pipelines/ddim/pipeline_ddim.py
View file @
f448360b
...
...
@@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline
class
DDIMPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
num_trained_timesteps
=
self
.
noise_scheduler
.
config
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
# Sample gaussian noise to begin loop
...
...
@@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline):
)
image
=
image
.
to
(
torch_device
)
#
See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
#
set step values
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
for
t
in
tqdm
.
tqdm
(
self
.
scheduler
.
timesteps
):
# 1. predict noise residual
with
torch
.
no_grad
():
residual
=
self
.
unet
(
image
,
inference_step_times
[
t
]
)
residual
=
self
.
unet
(
image
,
t
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
,
eta
)
# 3. optionally sample variance
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image
=
self
.
scheduler
.
step
(
residual
,
t
,
image
,
eta
)[
"prev_sample"
]
return
image
return
{
"sample"
:
image
}
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
f448360b
...
...
@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class
DDPMPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
):
if
torch_device
is
None
:
...
...
@@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline):
)
image
=
image
.
to
(
torch_device
)
num_prediction_steps
=
len
(
self
.
noise_
scheduler
)
num_prediction_steps
=
len
(
self
.
scheduler
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_prediction_steps
)),
total
=
num_prediction_steps
):
# 1. predict noise residual
with
torch
.
no_grad
():
...
...
@@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_
scheduler
.
step
(
residual
,
image
,
t
)
pred_prev_image
=
self
.
scheduler
.
step
(
residual
,
t
,
image
)[
"prev_sample"
]
# 3. optionally sample variance
variance
=
0
if
t
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_
scheduler
.
get_variance
(
t
).
sqrt
()
*
noise
variance
=
self
.
scheduler
.
get_variance
(
t
).
sqrt
()
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
...
...
src/diffusers/pipelines/glide/pipeline_glide.py
View file @
f448360b
...
...
@@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline):
def
__init__
(
self
,
text_unet
:
GlideTextToImageUNetModel
,
text_
noise_
scheduler
:
DDPMScheduler
,
text_scheduler
:
DDPMScheduler
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
GPT2Tokenizer
,
upscale_unet
:
GlideSuperResUNetModel
,
upscale_
noise_
scheduler
:
DDIMScheduler
,
upscale_scheduler
:
DDIMScheduler
,
):
super
().
__init__
()
self
.
register_modules
(
text_unet
=
text_unet
,
text_
noise_
scheduler
=
text_
noise_
scheduler
,
text_scheduler
=
text_scheduler
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
upscale_unet
=
upscale_unet
,
upscale_
noise_
scheduler
=
upscale_
noise_
scheduler
,
upscale_scheduler
=
upscale_scheduler
,
)
@
torch
.
no_grad
()
...
...
@@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline):
transformer_out
=
self
.
text_encoder
(
input_ids
,
attention_mask
).
last_hidden_state
# 3. Run the text2image generation step
num_prediction_steps
=
len
(
self
.
text_
noise_
scheduler
)
num_prediction_steps
=
len
(
self
.
text_scheduler
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_prediction_steps
)),
total
=
num_prediction_steps
):
with
torch
.
no_grad
():
time_input
=
torch
.
tensor
([
t
]
*
image
.
shape
[
0
],
device
=
torch_device
)
model_output
=
text_model_fn
(
image
,
time_input
,
transformer_out
)
noise_residual
,
model_var_values
=
torch
.
split
(
model_output
,
3
,
dim
=
1
)
min_log
=
self
.
text_
noise_
scheduler
.
get_variance
(
t
,
"fixed_small_log"
)
max_log
=
self
.
text_
noise_
scheduler
.
get_variance
(
t
,
"fixed_large_log"
)
min_log
=
self
.
text_scheduler
.
get_variance
(
t
,
"fixed_small_log"
)
max_log
=
self
.
text_scheduler
.
get_variance
(
t
,
"fixed_large_log"
)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac
=
(
model_var_values
+
1
)
/
2
model_log_variance
=
frac
*
max_log
+
(
1
-
frac
)
*
min_log
pred_prev_image
=
self
.
text_
noise_
scheduler
.
step
(
noise_residual
,
image
,
t
)
pred_prev_image
=
self
.
text_scheduler
.
step
(
noise_residual
,
image
,
t
)
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
torch_device
)
variance
=
torch
.
exp
(
0.5
*
model_log_variance
)
*
noise
...
...
@@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline):
).
to
(
torch_device
)
image
=
image
*
upsample_temp
num_trained_timesteps
=
self
.
upscale_
noise_
scheduler
.
timesteps
num_trained_timesteps
=
self
.
upscale_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps_upscale
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps_upscale
)),
total
=
num_inference_steps_upscale
):
...
...
@@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline):
noise_residual
,
pred_variance
=
torch
.
split
(
model_output
,
3
,
dim
=
1
)
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
upscale_
noise_
scheduler
.
step
(
pred_prev_image
=
self
.
upscale_scheduler
.
step
(
noise_residual
,
image
,
t
,
num_inference_steps_upscale
,
eta
,
use_clipped_residual
=
True
)
...
...
@@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline):
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
torch_device
)
variance
=
(
self
.
upscale_noise_scheduler
.
get_variance
(
t
,
num_inference_steps_upscale
).
sqrt
()
*
eta
*
noise
)
variance
=
self
.
upscale_scheduler
.
get_variance
(
t
,
num_inference_steps_upscale
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
...
...
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
View file @
f448360b
...
...
@@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel):
class
LatentDiffusionPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
def
__call__
(
...
...
@@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
).
to
(
torch_device
)
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)
num_trained_timesteps
=
self
.
noise_
scheduler
.
config
.
timesteps
num_trained_timesteps
=
self
.
scheduler
.
config
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
torch
.
randn
(
...
...
@@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline):
pred_noise_t
=
pred_noise_t_uncond
+
guidance_scale
*
(
pred_noise_t
-
pred_noise_t_uncond
)
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_
scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
pred_prev_image
=
self
.
scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
# 3. optionally sample variance
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_
scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
variance
=
self
.
scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
f448360b
...
...
@@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline
class
LatentDiffusionUncondPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
vqvae
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
def
__call__
(
...
...
@@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self
.
unet
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
num_trained_timesteps
=
self
.
noise_scheduler
.
config
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
torch
.
randn
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
image_size
,
self
.
unet
.
image_size
),
generator
=
generator
,
).
to
(
torch_device
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# 1. predict noise residual
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
pred_noise_t
=
self
.
unet
(
image
,
timesteps
)
if
isinstance
(
pred_noise_t
,
dict
):
pred_noise_t
=
pred_noise_t
[
"sample"
]
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
# 2. predict previous mean of image x_t-1
p
re
d_prev_image
=
self
.
noise_scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
for
t
in
tqdm
.
tqdm
(
self
.
scheduler
.
timesteps
):
re
sidual
=
self
.
unet
(
image
,
t
)
# 3. optionally sample variance
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image
=
self
.
scheduler
.
step
(
residual
,
t
,
image
,
eta
)[
"prev_sample"
]
# decode image with vae
image
=
self
.
vqvae
.
decode
(
image
)
return
image
return
{
"sample"
:
image
}
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
f448360b
...
...
@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class
PNDMPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
# For more information on the sampling method you can take a look at Algorithm 2 of
...
...
@@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline):
)
image
=
image
.
to
(
torch_device
)
prk_time_steps
=
self
.
noise_
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
prk_time_steps
=
self
.
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
prk_time_steps
))):
t_orig
=
prk_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
...
...
@@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline):
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_
scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
scheduler
.
step_prk
(
residual
,
t
,
image
,
num_inference_steps
)
[
"prev_sample"
]
timesteps
=
self
.
noise_
scheduler
.
get_time_steps
(
num_inference_steps
)
timesteps
=
self
.
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
timesteps
))):
t_orig
=
timesteps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
...
...
@@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline):
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_
scheduler
.
step_plms
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
scheduler
.
step_plms
(
residual
,
t
,
image
,
num_inference_steps
)
[
"prev_sample"
]
return
image
src/diffusers/schedulers/scheduling_ddim.py
View file @
f448360b
...
...
@@ -16,8 +16,10 @@
# and https://github.com/hojonathanho/diffusion
import
math
from
typing
import
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self
.
alphas_cumprod
=
np
.
cumprod
(
self
.
alphas
,
axis
=
0
)
self
.
one
=
np
.
array
(
1.0
)
self
.
set_format
(
tensor_format
=
tensor_format
)
# setable values
self
.
num_inference_steps
=
None
self
.
timesteps
=
np
.
arange
(
0
,
self
.
config
.
timesteps
)[::
-
1
].
copy
()
def
get_variance
(
self
,
t
,
num_inference_steps
):
orig_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
t
orig_prev_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
(
t
-
1
)
if
t
>
0
else
-
1
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
alpha_prod_t
=
self
.
alphas_cumprod
[
orig_t
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
orig_prev_t
]
if
orig_prev_t
>=
0
else
self
.
one
def
_get_variance
(
self
,
timestep
,
prev_timestep
):
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
prev_timestep
]
if
prev_timestep
>=
0
else
self
.
one
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
...
...
@@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return
variance
def
step
(
self
,
residual
,
sample
,
t
,
num_inference_steps
,
eta
,
use_clipped_residual
=
False
):
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
num_inference_steps
=
num_inference_steps
self
.
timesteps
=
np
.
arange
(
0
,
self
.
config
.
timesteps
,
self
.
config
.
timesteps
//
self
.
num_inference_steps
)[
::
-
1
].
copy
()
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
def
step
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
eta
,
use_clipped_residual
=
False
,
generator
=
None
,
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
...
...
@@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointingc to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get actual t and t-1
orig_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
t
orig_prev_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
(
t
-
1
)
if
t
>
0
else
-
1
# 1. get previous step value (=t-1)
prev_timestep
=
timestep
-
self
.
config
.
timesteps
//
self
.
num_inference_steps
# 2. compute alphas, betas
alpha_prod_t
=
self
.
alphas_cumprod
[
orig_t
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
orig_
prev_t
]
if
orig_prev_t
>=
0
else
self
.
one
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
prev_t
imestep
]
if
prev_timestep
>=
0
else
self
.
one
beta_prod_t
=
1
-
alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
...
...
@@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance
=
self
.
get_variance
(
t
,
num_inference_
step
s
)
variance
=
self
.
_
get_variance
(
t
imestep
,
prev_time
step
)
std_dev_t
=
eta
*
variance
**
(
0.5
)
if
use_clipped_residual
:
...
...
@@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
pred_sample_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
)
**
(
0.5
)
*
residual
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_sample
=
alpha_prod_t_prev
**
(
0.5
)
*
pred_original_sample
+
pred_sample_direction
prev_sample
=
alpha_prod_t_prev
**
(
0.5
)
*
pred_original_sample
+
pred_sample_direction
if
eta
>
0
:
device
=
residual
.
device
if
torch
.
is_tensor
(
residual
)
else
"cpu"
noise
=
torch
.
randn
(
residual
.
shape
,
generator
=
generator
).
to
(
device
)
variance
=
self
.
_get_variance
(
timestep
,
prev_timestep
)
**
(
0.5
)
*
eta
*
noise
if
not
torch
.
is_tensor
(
residual
):
variance
=
variance
.
numpy
()
prev_sample
=
prev_sample
+
variance
return
pre
d_
prev_sample
return
{
"
pre
v_sample"
:
prev_sample
}
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
f448360b
...
...
@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
from
typing
import
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
variance
def
step
(
self
,
residual
,
sample
,
t
,
predict_epsilon
=
True
):
def
step
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
predict_epsilon
=
True
,
):
t
=
timestep
# 1. compute alphas, betas
alpha_prod_t
=
self
.
alphas_cumprod
[
t
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
-
1
]
if
t
>
0
else
self
.
one
...
...
@@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample
=
pred_original_sample_coeff
*
pred_original_sample
+
current_sample_coeff
*
sample
return
pred_prev_sample
return
{
"prev_sample"
:
pred_prev_sample
}
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
f448360b
...
...
@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
from
typing
import
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
raise
ValueError
(
f
"mode
{
self
.
mode
}
does not exist."
)
def
step_prk
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
def
step_prk
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
t
=
timestep
prk_time_steps
=
self
.
get_prk_time_steps
(
num_inference_steps
)
t_orig
=
prk_time_steps
[
t
//
4
*
4
]
...
...
@@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# cur_sample should not be `None`
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
return
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
residual
)
return
{
"prev_sample"
:
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
residual
)
}
def
step_plms
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
def
step_plms
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
t
=
timestep
if
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
...
...
@@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
residual
)
return
{
"prev_sample"
:
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
residual
)
}
def
get_prev_sample
(
self
,
sample
,
t_orig
,
t_orig_prev
,
residual
):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
...
...
tests/test_modeling_utils.py
View file @
f448360b
...
...
@@ -165,7 +165,7 @@ class ModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names
=
[
*
signature
.
parameters
.
keys
()]
expected_arg_names
=
[
"sample"
,
"step
_value
"
]
expected_arg_names
=
[
"sample"
,
"
time
step"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
...
...
@@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
,
"low_res"
:
low_res
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
input_shape
(
self
):
...
...
@@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb
=
torch
.
randn
((
batch_size
,
seq_len
,
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
,
"transformer_out"
:
emb
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
,
"transformer_out"
:
emb
}
@
property
def
input_shape
(
self
):
...
...
@@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm
.
noise_
scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
noise_
scheduler
.
num_timesteps
=
10
ddpm
.
scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
scheduler
.
num_timesteps
=
10
generator
=
torch
.
manual_seed
(
0
)
...
...
@@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_
scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
...
...
@@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-lsun-bedroom-ema"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
image
=
ddpm
(
generator
=
generator
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
...
@@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
...
@@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
pndm
(
generator
=
generator
)
...
...
@@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
,
ldm
=
True
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
...
tests/test_scheduler.py
View file @
f448360b
...
...
@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
sample
=
self
.
dummy_sample
...
...
@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
...
@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
...
...
@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
torch
.
manual_seed
(
0
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
torch
.
manual_seed
(
0
)
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_pretrained_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
...
...
@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
...
@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase):
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
output_0
=
scheduler
.
step
(
residual
,
sample
,
0
,
**
kwargs
)
output_1
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output_0
=
scheduler
.
step
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
...
...
@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase):
def
test_pytorch_equal_numpy
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
...
...
@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
output
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
sample_pt
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler_pt
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
...
...
@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual
=
model
(
sample
,
t
)
# 2. predict previous mean of sample x_t-1
pred_prev_sample
=
scheduler
.
step
(
residual
,
sample
,
t
)
pred_prev_sample
=
scheduler
.
step
(
residual
,
t
,
sample
)[
"prev_sample"
]
if
t
>
0
:
noise
=
self
.
dummy_sample_deter
...
...
@@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
class
DDIMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
DDIMScheduler
,)
forward_default_kwargs
=
((
"num_inference_steps"
,
50
)
,
(
"eta"
,
0.0
)
)
forward_default_kwargs
=
(
(
"eta"
,
0.0
),
(
"num_inference_steps"
,
50
))
def
get_scheduler_config
(
self
,
**
kwargs
):
config
=
{
...
...
@@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
return
config
def
test_timesteps
(
self
):
for
timesteps
in
[
1
,
5
,
1
00
,
1000
]:
for
timesteps
in
[
1
00
,
500
,
1000
]:
self
.
check_over_configs
(
timesteps
=
timesteps
)
def
test_betas
(
self
):
...
...
@@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def
test_inference_steps
(
self
):
for
t
,
num_inference_steps
in
zip
([
1
,
10
,
50
],
[
10
,
50
,
500
]):
self
.
check_over_forward
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
self
.
check_over_forward
(
num_inference_steps
=
num_inference_steps
)
def
test_eta
(
self
):
for
t
,
eta
in
zip
([
1
,
10
,
49
],
[
0.0
,
0.5
,
1.0
]):
...
...
@@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
0
,
5
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
21
,
5
0
)
-
0.14771
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
49
,
5
0
)
-
0.32460
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
0
,
100
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
487
,
1000
)
-
0.00979
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
999
,
1000
)
-
0.02
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
0
,
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
420
,
40
0
)
-
0.14771
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
980
,
96
0
)
-
0.32460
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
0
,
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
487
,
486
)
-
0.00979
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
999
,
998
)
-
0.02
))
<
1e-5
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
num_inference_steps
,
eta
=
10
,
0.1
num_trained_timesteps
=
len
(
scheduler
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
num_inference_steps
,
eta
=
10
,
0.0
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
for
t
in
reversed
(
range
(
num_inference_steps
)):
residual
=
model
(
sample
,
inference_step_times
[
t
])
pred_prev_sample
=
scheduler
.
step
(
residual
,
sample
,
t
,
num_inference_steps
,
eta
)
variance
=
0
if
eta
>
0
:
noise
=
self
.
dummy_sample_deter
variance
=
scheduler
.
get_variance
(
t
,
num_inference_steps
)
**
(
0.5
)
*
eta
*
noise
scheduler
.
set_timesteps
(
num_inference_steps
)
for
t
in
scheduler
.
timesteps
:
residual
=
model
(
sample
,
t
)
sample
=
pred_prev_sample
+
variance
sample
=
scheduler
.
step
(
residual
,
t
,
sample
,
eta
)[
"prev_sample"
]
result_sum
=
np
.
sum
(
np
.
abs
(
sample
))
result_mean
=
np
.
mean
(
np
.
abs
(
sample
))
assert
abs
(
result_sum
.
item
()
-
270.6214
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.
3524
)
<
1e-3
assert
abs
(
result_sum
.
item
()
-
172.0067
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.
223967
)
<
1e-3
class
PNDMSchedulerTest
(
SchedulerCommonTest
):
...
...
@@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
...
@@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
...
@@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler
.
set_plms_mode
()
scheduler
.
step
(
self
.
dummy_sample
,
self
.
dummy_sample
,
1
,
50
)
scheduler
.
step
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
,
50
)
[
"prev_sample"
]
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
...
@@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
t_orig
=
prk_time_steps
[
t
]
residual
=
model
(
sample
,
t_orig
)
sample
=
scheduler
.
step_prk
(
residual
,
sample
,
t
,
num_inference_steps
)
sample
=
scheduler
.
step_prk
(
residual
,
t
,
sample
,
num_inference_steps
)
[
"prev_sample"
]
timesteps
=
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
range
(
len
(
timesteps
)):
t_orig
=
timesteps
[
t
]
residual
=
model
(
sample
,
t_orig
)
sample
=
scheduler
.
step_plms
(
residual
,
sample
,
t
,
num_inference_steps
)
sample
=
scheduler
.
step_plms
(
residual
,
t
,
sample
,
num_inference_steps
)
[
"prev_sample"
]
result_sum
=
np
.
sum
(
np
.
abs
(
sample
))
result_mean
=
np
.
mean
(
np
.
abs
(
sample
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment