Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f448360b
Unverified
Commit
f448360b
authored
Jul 15, 2022
by
Patrick von Platen
Committed by
GitHub
Jul 15, 2022
Browse files
Finish scheduler API (#91)
* finish * up
parent
97e1e3ba
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
233 additions
and
188 deletions
+233
-188
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+4
-4
src/diffusers/models/unet_sde_score_estimation.py
src/diffusers/models/unet_sde_score_estimation.py
+2
-2
src/diffusers/models/unet_unconditional.py
src/diffusers/models/unet_unconditional.py
+5
-3
src/diffusers/pipelines/ddim/pipeline_ddim.py
src/diffusers/pipelines/ddim/pipeline_ddim.py
+11
-29
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+6
-6
src/diffusers/pipelines/glide/pipeline_glide.py
src/diffusers/pipelines/glide/pipeline_glide.py
+11
-13
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
...s/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+6
-6
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+12
-33
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+7
-7
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+43
-15
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+11
-2
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+20
-4
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+20
-21
tests/test_scheduler.py
tests/test_scheduler.py
+75
-43
No files found.
src/diffusers/models/unet_glide.py
View file @
f448360b
...
...
@@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
def
forward
(
self
,
sample
,
step
_value
,
transformer_out
=
None
):
timesteps
=
step
_value
def
forward
(
self
,
sample
,
time
step
,
transformer_out
=
None
):
timesteps
=
time
step
x
=
sample
hs
=
[]
emb
=
self
.
time_embed
(
...
...
@@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
resblock_updown
=
resblock_updown
,
)
def
forward
(
self
,
sample
,
step
_value
,
low_res
=
None
):
timesteps
=
step
_value
def
forward
(
self
,
sample
,
time
step
,
low_res
=
None
):
timesteps
=
time
step
x
=
sample
_
,
_
,
new_height
,
new_width
=
x
.
shape
upsampled
=
F
.
interpolate
(
low_res
,
(
new_height
,
new_width
),
mode
=
"bilinear"
)
...
...
src/diffusers/models/unet_sde_score_estimation.py
View file @
f448360b
...
...
@@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
self
.
all_modules
=
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
sample
,
step
_value
,
sigmas
=
None
):
timesteps
=
step
_value
def
forward
(
self
,
sample
,
time
step
,
sigmas
=
None
):
timesteps
=
time
step
x
=
sample
# timestep/noise_level embedding; only for continuous training
modules
=
self
.
all_modules
...
...
src/diffusers/models/unet_unconditional.py
View file @
f448360b
...
...
@@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# ======================================
def
forward
(
self
,
sample
:
torch
.
FloatTensor
,
step
_value
:
Union
[
torch
.
Tensor
,
float
,
int
]
self
,
sample
:
torch
.
FloatTensor
,
time
step
:
Union
[
torch
.
Tensor
,
float
,
int
]
)
->
Dict
[
str
,
torch
.
FloatTensor
]:
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
...
...
@@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self
.
set_weights
()
# ======================================
# 1. time step embeddings
timesteps
=
step
_value
# 1. time step embeddings
-> make correct tensor
timesteps
=
time
step
if
not
torch
.
is_tensor
(
timesteps
):
timesteps
=
torch
.
tensor
([
timesteps
],
dtype
=
torch
.
long
,
device
=
sample
.
device
)
elif
torch
.
is_tensor
(
timesteps
)
and
len
(
timesteps
.
shape
)
==
0
:
timesteps
=
timesteps
[
None
].
to
(
sample
.
device
)
t_emb
=
get_timestep_embedding
(
timesteps
,
...
...
src/diffusers/pipelines/ddim/pipeline_ddim.py
View file @
f448360b
...
...
@@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline
class
DDIMPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
num_trained_timesteps
=
self
.
noise_scheduler
.
config
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
# Sample gaussian noise to begin loop
...
...
@@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline):
)
image
=
image
.
to
(
torch_device
)
#
See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
#
set step values
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
for
t
in
tqdm
.
tqdm
(
self
.
scheduler
.
timesteps
):
# 1. predict noise residual
with
torch
.
no_grad
():
residual
=
self
.
unet
(
image
,
inference_step_times
[
t
]
)
residual
=
self
.
unet
(
image
,
t
)
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
,
eta
)
# 3. optionally sample variance
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image
=
self
.
scheduler
.
step
(
residual
,
t
,
image
,
eta
)[
"prev_sample"
]
return
image
return
{
"sample"
:
image
}
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
f448360b
...
...
@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class
DDPMPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
):
if
torch_device
is
None
:
...
...
@@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline):
)
image
=
image
.
to
(
torch_device
)
num_prediction_steps
=
len
(
self
.
noise_
scheduler
)
num_prediction_steps
=
len
(
self
.
scheduler
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_prediction_steps
)),
total
=
num_prediction_steps
):
# 1. predict noise residual
with
torch
.
no_grad
():
...
...
@@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline):
residual
=
residual
[
"sample"
]
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_
scheduler
.
step
(
residual
,
image
,
t
)
pred_prev_image
=
self
.
scheduler
.
step
(
residual
,
t
,
image
)[
"prev_sample"
]
# 3. optionally sample variance
variance
=
0
if
t
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_
scheduler
.
get_variance
(
t
).
sqrt
()
*
noise
variance
=
self
.
scheduler
.
get_variance
(
t
).
sqrt
()
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
...
...
src/diffusers/pipelines/glide/pipeline_glide.py
View file @
f448360b
...
...
@@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline):
def
__init__
(
self
,
text_unet
:
GlideTextToImageUNetModel
,
text_
noise_
scheduler
:
DDPMScheduler
,
text_scheduler
:
DDPMScheduler
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
GPT2Tokenizer
,
upscale_unet
:
GlideSuperResUNetModel
,
upscale_
noise_
scheduler
:
DDIMScheduler
,
upscale_scheduler
:
DDIMScheduler
,
):
super
().
__init__
()
self
.
register_modules
(
text_unet
=
text_unet
,
text_
noise_
scheduler
=
text_
noise_
scheduler
,
text_scheduler
=
text_scheduler
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
,
upscale_unet
=
upscale_unet
,
upscale_
noise_
scheduler
=
upscale_
noise_
scheduler
,
upscale_scheduler
=
upscale_scheduler
,
)
@
torch
.
no_grad
()
...
...
@@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline):
transformer_out
=
self
.
text_encoder
(
input_ids
,
attention_mask
).
last_hidden_state
# 3. Run the text2image generation step
num_prediction_steps
=
len
(
self
.
text_
noise_
scheduler
)
num_prediction_steps
=
len
(
self
.
text_scheduler
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_prediction_steps
)),
total
=
num_prediction_steps
):
with
torch
.
no_grad
():
time_input
=
torch
.
tensor
([
t
]
*
image
.
shape
[
0
],
device
=
torch_device
)
model_output
=
text_model_fn
(
image
,
time_input
,
transformer_out
)
noise_residual
,
model_var_values
=
torch
.
split
(
model_output
,
3
,
dim
=
1
)
min_log
=
self
.
text_
noise_
scheduler
.
get_variance
(
t
,
"fixed_small_log"
)
max_log
=
self
.
text_
noise_
scheduler
.
get_variance
(
t
,
"fixed_large_log"
)
min_log
=
self
.
text_scheduler
.
get_variance
(
t
,
"fixed_small_log"
)
max_log
=
self
.
text_scheduler
.
get_variance
(
t
,
"fixed_large_log"
)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac
=
(
model_var_values
+
1
)
/
2
model_log_variance
=
frac
*
max_log
+
(
1
-
frac
)
*
min_log
pred_prev_image
=
self
.
text_
noise_
scheduler
.
step
(
noise_residual
,
image
,
t
)
pred_prev_image
=
self
.
text_scheduler
.
step
(
noise_residual
,
image
,
t
)
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
torch_device
)
variance
=
torch
.
exp
(
0.5
*
model_log_variance
)
*
noise
...
...
@@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline):
).
to
(
torch_device
)
image
=
image
*
upsample_temp
num_trained_timesteps
=
self
.
upscale_
noise_
scheduler
.
timesteps
num_trained_timesteps
=
self
.
upscale_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps_upscale
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps_upscale
)),
total
=
num_inference_steps_upscale
):
...
...
@@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline):
noise_residual
,
pred_variance
=
torch
.
split
(
model_output
,
3
,
dim
=
1
)
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
upscale_
noise_
scheduler
.
step
(
pred_prev_image
=
self
.
upscale_scheduler
.
step
(
noise_residual
,
image
,
t
,
num_inference_steps_upscale
,
eta
,
use_clipped_residual
=
True
)
...
...
@@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline):
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
torch_device
)
variance
=
(
self
.
upscale_noise_scheduler
.
get_variance
(
t
,
num_inference_steps_upscale
).
sqrt
()
*
eta
*
noise
)
variance
=
self
.
upscale_scheduler
.
get_variance
(
t
,
num_inference_steps_upscale
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
...
...
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
View file @
f448360b
...
...
@@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel):
class
LatentDiffusionPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
vqvae
,
bert
,
tokenizer
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
bert
=
bert
,
tokenizer
=
tokenizer
,
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
def
__call__
(
...
...
@@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
).
to
(
torch_device
)
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)
num_trained_timesteps
=
self
.
noise_
scheduler
.
config
.
timesteps
num_trained_timesteps
=
self
.
scheduler
.
config
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
torch
.
randn
(
...
...
@@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline):
pred_noise_t
=
pred_noise_t_uncond
+
guidance_scale
*
(
pred_noise_t
-
pred_noise_t_uncond
)
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_
scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
pred_prev_image
=
self
.
scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
# 3. optionally sample variance
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_
scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
variance
=
self
.
scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
f448360b
...
...
@@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline
class
LatentDiffusionUncondPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
vqvae
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
vqvae
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
vqvae
=
vqvae
,
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
def
__call__
(
...
...
@@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self
.
unet
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
num_trained_timesteps
=
self
.
noise_scheduler
.
config
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
torch
.
randn
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
image_size
,
self
.
unet
.
image_size
),
generator
=
generator
,
).
to
(
torch_device
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# 1. predict noise residual
timesteps
=
torch
.
tensor
([
inference_step_times
[
t
]]
*
image
.
shape
[
0
],
device
=
torch_device
)
pred_noise_t
=
self
.
unet
(
image
,
timesteps
)
if
isinstance
(
pred_noise_t
,
dict
):
pred_noise_t
=
pred_noise_t
[
"sample"
]
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
# 2. predict previous mean of image x_t-1
p
re
d_prev_image
=
self
.
noise_scheduler
.
step
(
pred_noise_t
,
image
,
t
,
num_inference_steps
,
eta
)
for
t
in
tqdm
.
tqdm
(
self
.
scheduler
.
timesteps
):
re
sidual
=
self
.
unet
(
image
,
t
)
# 3. optionally sample variance
variance
=
0
if
eta
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
noise_scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image
=
self
.
scheduler
.
step
(
residual
,
t
,
image
,
eta
)[
"prev_sample"
]
# decode image with vae
image
=
self
.
vqvae
.
decode
(
image
)
return
image
return
{
"sample"
:
image
}
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
f448360b
...
...
@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class
PNDMPipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_
scheduler
):
def
__init__
(
self
,
unet
,
scheduler
):
super
().
__init__
()
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
# For more information on the sampling method you can take a look at Algorithm 2 of
...
...
@@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline):
)
image
=
image
.
to
(
torch_device
)
prk_time_steps
=
self
.
noise_
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
prk_time_steps
=
self
.
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
prk_time_steps
))):
t_orig
=
prk_time_steps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
...
...
@@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline):
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_
scheduler
.
step_prk
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
scheduler
.
step_prk
(
residual
,
t
,
image
,
num_inference_steps
)
[
"prev_sample"
]
timesteps
=
self
.
noise_
scheduler
.
get_time_steps
(
num_inference_steps
)
timesteps
=
self
.
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
range
(
len
(
timesteps
))):
t_orig
=
timesteps
[
t
]
residual
=
self
.
unet
(
image
,
t_orig
)
...
...
@@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline):
if
isinstance
(
residual
,
dict
):
residual
=
residual
[
"sample"
]
image
=
self
.
noise_
scheduler
.
step_plms
(
residual
,
image
,
t
,
num_inference_steps
)
image
=
self
.
scheduler
.
step_plms
(
residual
,
t
,
image
,
num_inference_steps
)
[
"prev_sample"
]
return
image
src/diffusers/schedulers/scheduling_ddim.py
View file @
f448360b
...
...
@@ -16,8 +16,10 @@
# and https://github.com/hojonathanho/diffusion
import
math
from
typing
import
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self
.
alphas_cumprod
=
np
.
cumprod
(
self
.
alphas
,
axis
=
0
)
self
.
one
=
np
.
array
(
1.0
)
self
.
set_format
(
tensor_format
=
tensor_format
)
# setable values
self
.
num_inference_steps
=
None
self
.
timesteps
=
np
.
arange
(
0
,
self
.
config
.
timesteps
)[::
-
1
].
copy
()
def
get_variance
(
self
,
t
,
num_inference_steps
):
orig_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
t
orig_prev_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
(
t
-
1
)
if
t
>
0
else
-
1
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
alpha_prod_t
=
self
.
alphas_cumprod
[
orig_t
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
orig_prev_t
]
if
orig_prev_t
>=
0
else
self
.
one
def
_get_variance
(
self
,
timestep
,
prev_timestep
):
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
prev_timestep
]
if
prev_timestep
>=
0
else
self
.
one
beta_prod_t
=
1
-
alpha_prod_t
beta_prod_t_prev
=
1
-
alpha_prod_t_prev
...
...
@@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return
variance
def
step
(
self
,
residual
,
sample
,
t
,
num_inference_steps
,
eta
,
use_clipped_residual
=
False
):
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
num_inference_steps
=
num_inference_steps
self
.
timesteps
=
np
.
arange
(
0
,
self
.
config
.
timesteps
,
self
.
config
.
timesteps
//
self
.
num_inference_steps
)[
::
-
1
].
copy
()
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
def
step
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
eta
,
use_clipped_residual
=
False
,
generator
=
None
,
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
...
...
@@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointingc to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get actual t and t-1
orig_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
t
orig_prev_t
=
self
.
config
.
timesteps
//
num_inference_steps
*
(
t
-
1
)
if
t
>
0
else
-
1
# 1. get previous step value (=t-1)
prev_timestep
=
timestep
-
self
.
config
.
timesteps
//
self
.
num_inference_steps
# 2. compute alphas, betas
alpha_prod_t
=
self
.
alphas_cumprod
[
orig_t
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
orig_
prev_t
]
if
orig_prev_t
>=
0
else
self
.
one
alpha_prod_t
=
self
.
alphas_cumprod
[
timestep
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
prev_t
imestep
]
if
prev_timestep
>=
0
else
self
.
one
beta_prod_t
=
1
-
alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
...
...
@@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance
=
self
.
get_variance
(
t
,
num_inference_
step
s
)
variance
=
self
.
_
get_variance
(
t
imestep
,
prev_time
step
)
std_dev_t
=
eta
*
variance
**
(
0.5
)
if
use_clipped_residual
:
...
...
@@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
pred_sample_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
)
**
(
0.5
)
*
residual
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_sample
=
alpha_prod_t_prev
**
(
0.5
)
*
pred_original_sample
+
pred_sample_direction
prev_sample
=
alpha_prod_t_prev
**
(
0.5
)
*
pred_original_sample
+
pred_sample_direction
if
eta
>
0
:
device
=
residual
.
device
if
torch
.
is_tensor
(
residual
)
else
"cpu"
noise
=
torch
.
randn
(
residual
.
shape
,
generator
=
generator
).
to
(
device
)
variance
=
self
.
_get_variance
(
timestep
,
prev_timestep
)
**
(
0.5
)
*
eta
*
noise
if
not
torch
.
is_tensor
(
residual
):
variance
=
variance
.
numpy
()
prev_sample
=
prev_sample
+
variance
return
pre
d_
prev_sample
return
{
"
pre
v_sample"
:
prev_sample
}
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
f448360b
...
...
@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
from
typing
import
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return
variance
def
step
(
self
,
residual
,
sample
,
t
,
predict_epsilon
=
True
):
def
step
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
predict_epsilon
=
True
,
):
t
=
timestep
# 1. compute alphas, betas
alpha_prod_t
=
self
.
alphas_cumprod
[
t
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
-
1
]
if
t
>
0
else
self
.
one
...
...
@@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample
=
pred_original_sample_coeff
*
pred_original_sample
+
current_sample_coeff
*
sample
return
pred_prev_sample
return
{
"prev_sample"
:
pred_prev_sample
}
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
f448360b
...
...
@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
from
typing
import
Union
import
numpy
as
np
import
torch
from
..configuration_utils
import
ConfigMixin
from
.scheduling_utils
import
SchedulerMixin
...
...
@@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
raise
ValueError
(
f
"mode
{
self
.
mode
}
does not exist."
)
def
step_prk
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
def
step_prk
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
t
=
timestep
prk_time_steps
=
self
.
get_prk_time_steps
(
num_inference_steps
)
t_orig
=
prk_time_steps
[
t
//
4
*
4
]
...
...
@@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# cur_sample should not be `None`
cur_sample
=
self
.
cur_sample
if
self
.
cur_sample
is
not
None
else
sample
return
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
residual
)
return
{
"prev_sample"
:
self
.
get_prev_sample
(
cur_sample
,
t_orig
,
t_orig_prev
,
residual
)
}
def
step_plms
(
self
,
residual
,
sample
,
t
,
num_inference_steps
):
def
step_plms
(
self
,
residual
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
num_inference_steps
,
):
t
=
timestep
if
len
(
self
.
ets
)
<
3
:
raise
ValueError
(
f
"
{
self
.
__class__
}
can only be run AFTER scheduler has been run "
...
...
@@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual
=
(
1
/
24
)
*
(
55
*
self
.
ets
[
-
1
]
-
59
*
self
.
ets
[
-
2
]
+
37
*
self
.
ets
[
-
3
]
-
9
*
self
.
ets
[
-
4
])
return
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
residual
)
return
{
"prev_sample"
:
self
.
get_prev_sample
(
sample
,
t_orig
,
t_orig_prev
,
residual
)
}
def
get_prev_sample
(
self
,
sample
,
t_orig
,
t_orig_prev
,
residual
):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
...
...
tests/test_modeling_utils.py
View file @
f448360b
...
...
@@ -165,7 +165,7 @@ class ModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names
=
[
*
signature
.
parameters
.
keys
()]
expected_arg_names
=
[
"sample"
,
"step
_value
"
]
expected_arg_names
=
[
"sample"
,
"
time
step"
]
self
.
assertListEqual
(
arg_names
[:
2
],
expected_arg_names
)
def
test_model_from_config
(
self
):
...
...
@@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
,
"low_res"
:
low_res
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
input_shape
(
self
):
...
...
@@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb
=
torch
.
randn
((
batch_size
,
seq_len
,
transformer_dim
)).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
,
"transformer_out"
:
emb
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
,
"transformer_out"
:
emb
}
@
property
def
input_shape
(
self
):
...
...
@@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
10
]).
to
(
torch_device
)
return
{
"sample"
:
noise
,
"step
_value
"
:
time_step
}
return
{
"sample"
:
noise
,
"
time
step"
:
time_step
}
@
property
def
input_shape
(
self
):
...
...
@@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm
.
noise_
scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
noise_
scheduler
.
num_timesteps
=
10
ddpm
.
scheduler
.
num_timesteps
=
10
ddpm_from_hub
.
scheduler
.
num_timesteps
=
10
generator
=
torch
.
manual_seed
(
0
)
...
...
@@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_
scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
noise_
scheduler
=
noise_
scheduler
.
set_format
(
"pt"
)
scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
...
...
@@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-lsun-bedroom-ema"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
noise_scheduler
=
noise_scheduler
.
set_format
(
"pt"
)
scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)
image
=
ddpm
(
generator
=
generator
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
...
@@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
...
@@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase):
model_id
=
"fusing/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
noise_
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
noise_
scheduler
=
noise_
scheduler
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
image
=
pndm
(
generator
=
generator
)
...
...
@@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"fusing/latent-diffusion-celeba-256"
,
ldm
=
True
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)
[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
...
...
tests/test_scheduler.py
View file @
f448360b
...
...
@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
sample
=
self
.
dummy_sample
...
...
@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
...
@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
...
...
@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
torch
.
manual_seed
(
0
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
torch
.
manual_seed
(
0
)
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_pretrained_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
...
...
@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
...
@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase):
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
output_0
=
scheduler
.
step
(
residual
,
sample
,
0
,
**
kwargs
)
output_1
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output_0
=
scheduler
.
step
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
...
...
@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase):
def
test_pytorch_equal_numpy
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
...
...
@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
output
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
sample_pt
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler_pt
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
...
...
@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual
=
model
(
sample
,
t
)
# 2. predict previous mean of sample x_t-1
pred_prev_sample
=
scheduler
.
step
(
residual
,
sample
,
t
)
pred_prev_sample
=
scheduler
.
step
(
residual
,
t
,
sample
)[
"prev_sample"
]
if
t
>
0
:
noise
=
self
.
dummy_sample_deter
...
...
@@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
class
DDIMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
DDIMScheduler
,)
forward_default_kwargs
=
((
"num_inference_steps"
,
50
)
,
(
"eta"
,
0.0
)
)
forward_default_kwargs
=
(
(
"eta"
,
0.0
),
(
"num_inference_steps"
,
50
))
def
get_scheduler_config
(
self
,
**
kwargs
):
config
=
{
...
...
@@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
return
config
def
test_timesteps
(
self
):
for
timesteps
in
[
1
,
5
,
1
00
,
1000
]:
for
timesteps
in
[
1
00
,
500
,
1000
]:
self
.
check_over_configs
(
timesteps
=
timesteps
)
def
test_betas
(
self
):
...
...
@@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def
test_inference_steps
(
self
):
for
t
,
num_inference_steps
in
zip
([
1
,
10
,
50
],
[
10
,
50
,
500
]):
self
.
check_over_forward
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
self
.
check_over_forward
(
num_inference_steps
=
num_inference_steps
)
def
test_eta
(
self
):
for
t
,
eta
in
zip
([
1
,
10
,
49
],
[
0.0
,
0.5
,
1.0
]):
...
...
@@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
0
,
5
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
21
,
5
0
)
-
0.14771
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
49
,
5
0
)
-
0.32460
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
0
,
100
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
487
,
1000
)
-
0.00979
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
get_variance
(
999
,
1000
)
-
0.02
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
0
,
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
420
,
40
0
)
-
0.14771
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
980
,
96
0
)
-
0.32460
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
0
,
0
)
-
0.0
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
487
,
486
)
-
0.00979
))
<
1e-5
assert
np
.
sum
(
np
.
abs
(
scheduler
.
_
get_variance
(
999
,
998
)
-
0.02
))
<
1e-5
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
num_inference_steps
,
eta
=
10
,
0.1
num_trained_timesteps
=
len
(
scheduler
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
num_inference_steps
,
eta
=
10
,
0.0
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
for
t
in
reversed
(
range
(
num_inference_steps
)):
residual
=
model
(
sample
,
inference_step_times
[
t
])
pred_prev_sample
=
scheduler
.
step
(
residual
,
sample
,
t
,
num_inference_steps
,
eta
)
variance
=
0
if
eta
>
0
:
noise
=
self
.
dummy_sample_deter
variance
=
scheduler
.
get_variance
(
t
,
num_inference_steps
)
**
(
0.5
)
*
eta
*
noise
scheduler
.
set_timesteps
(
num_inference_steps
)
for
t
in
scheduler
.
timesteps
:
residual
=
model
(
sample
,
t
)
sample
=
pred_prev_sample
+
variance
sample
=
scheduler
.
step
(
residual
,
t
,
sample
,
eta
)[
"prev_sample"
]
result_sum
=
np
.
sum
(
np
.
abs
(
sample
))
result_mean
=
np
.
mean
(
np
.
abs
(
sample
))
assert
abs
(
result_sum
.
item
()
-
270.6214
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.
3524
)
<
1e-3
assert
abs
(
result_sum
.
item
()
-
172.0067
)
<
1e-2
assert
abs
(
result_mean
.
item
()
-
0.
223967
)
<
1e-3
class
PNDMSchedulerTest
(
SchedulerCommonTest
):
...
...
@@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
...
@@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)
[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
...
@@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler
.
set_plms_mode
()
scheduler
.
step
(
self
.
dummy_sample
,
self
.
dummy_sample
,
1
,
50
)
scheduler
.
step
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
,
50
)
[
"prev_sample"
]
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
...
@@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
t_orig
=
prk_time_steps
[
t
]
residual
=
model
(
sample
,
t_orig
)
sample
=
scheduler
.
step_prk
(
residual
,
sample
,
t
,
num_inference_steps
)
sample
=
scheduler
.
step_prk
(
residual
,
t
,
sample
,
num_inference_steps
)
[
"prev_sample"
]
timesteps
=
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
range
(
len
(
timesteps
)):
t_orig
=
timesteps
[
t
]
residual
=
model
(
sample
,
t_orig
)
sample
=
scheduler
.
step_plms
(
residual
,
sample
,
t
,
num_inference_steps
)
sample
=
scheduler
.
step_plms
(
residual
,
t
,
sample
,
num_inference_steps
)
[
"prev_sample"
]
result_sum
=
np
.
sum
(
np
.
abs
(
sample
))
result_mean
=
np
.
mean
(
np
.
abs
(
sample
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment