Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
2d97544d
Commit
2d97544d
authored
Jun 12, 2022
by
Patrick von Platen
Browse files
add more tests schedulers
parent
bda825f9
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
272 additions
and
40 deletions
+272
-40
src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py
...pelines/old/latent_diffusion/modeling_latent_diffusion.py
+1
-1
src/diffusers/pipelines/pipeline_ddim.py
src/diffusers/pipelines/pipeline_ddim.py
+3
-3
src/diffusers/pipelines/pipeline_latent_diffusion.py
src/diffusers/pipelines/pipeline_latent_diffusion.py
+1
-1
src/diffusers/schedulers/classifier_free_guidance.py
src/diffusers/schedulers/classifier_free_guidance.py
+2
-2
src/diffusers/schedulers/ddim.py
src/diffusers/schedulers/ddim.py
+4
-4
src/diffusers/schedulers/gaussian_ddpm.py
src/diffusers/schedulers/gaussian_ddpm.py
+3
-3
src/diffusers/schedulers/glide_ddim.py
src/diffusers/schedulers/glide_ddim.py
+3
-3
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+4
-1
tests/test_scheduler.py
tests/test_scheduler.py
+251
-22
No files found.
src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py
View file @
2d97544d
...
@@ -33,7 +33,7 @@ class LatentDiffusion(DiffusionPipeline):
...
@@ -33,7 +33,7 @@ class LatentDiffusion(DiffusionPipeline):
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)[
0
]
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)[
0
]
num_trained_timesteps
=
self
.
noise_scheduler
.
num_
timesteps
num_trained_timesteps
=
self
.
noise_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
self
.
noise_scheduler
.
sample_noise
(
image
=
self
.
noise_scheduler
.
sample_noise
(
...
...
src/diffusers/pipelines/pipeline_ddim.py
View file @
2d97544d
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
import
torch
import
torch
import
tqdm
import
tqdm
from
..
import
DiffusionPipeline
from
..
pipeline_utils
import
DiffusionPipeline
class
DDIM
(
DiffusionPipeline
):
class
DDIM
(
DiffusionPipeline
):
...
@@ -30,7 +30,7 @@ class DDIM(DiffusionPipeline):
...
@@ -30,7 +30,7 @@ class DDIM(DiffusionPipeline):
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
num_trained_timesteps
=
self
.
noise_scheduler
.
num_
timesteps
num_trained_timesteps
=
self
.
noise_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
...
@@ -64,7 +64,7 @@ class DDIM(DiffusionPipeline):
...
@@ -64,7 +64,7 @@ class DDIM(DiffusionPipeline):
variance
=
0
variance
=
0
if
eta
>
0
:
if
eta
>
0
:
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
image
.
device
,
generator
=
generator
)
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
image
.
device
,
generator
=
generator
)
variance
=
self
.
noise_scheduler
.
get_variance
(
t
).
sqrt
()
*
eta
*
noise
variance
=
self
.
noise_scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
image
=
pred_prev_image
+
variance
...
...
src/diffusers/pipelines/pipeline_latent_diffusion.py
View file @
2d97544d
...
@@ -883,7 +883,7 @@ class LatentDiffusion(DiffusionPipeline):
...
@@ -883,7 +883,7 @@ class LatentDiffusion(DiffusionPipeline):
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
'pt'
).
to
(
torch_device
)
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)[
0
]
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)[
0
]
num_trained_timesteps
=
self
.
noise_scheduler
.
num_
timesteps
num_trained_timesteps
=
self
.
noise_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
image
=
self
.
noise_scheduler
.
sample_noise
(
image
=
self
.
noise_scheduler
.
sample_noise
(
...
...
src/diffusers/schedulers/classifier_free_guidance.py
View file @
2d97544d
...
@@ -61,7 +61,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
...
@@ -61,7 +61,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
timesteps
=
timesteps
,
timesteps
=
timesteps
,
beta_schedule
=
beta_schedule
,
beta_schedule
=
beta_schedule
,
)
)
self
.
num_
timesteps
=
int
(
timesteps
)
self
.
timesteps
=
int
(
timesteps
)
if
beta_schedule
==
"squaredcos_cap_v2"
:
if
beta_schedule
==
"squaredcos_cap_v2"
:
# GLIDE cosine schedule
# GLIDE cosine schedule
...
@@ -94,4 +94,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
...
@@ -94,4 +94,4 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
num_
timesteps
return
self
.
timesteps
src/diffusers/schedulers/ddim.py
View file @
2d97544d
...
@@ -42,7 +42,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
...
@@ -42,7 +42,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
beta_end
=
beta_end
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
beta_schedule
=
beta_schedule
,
)
)
self
.
num_
timesteps
=
int
(
timesteps
)
self
.
timesteps
=
int
(
timesteps
)
self
.
clip_image
=
clip_predicted_image
self
.
clip_image
=
clip_predicted_image
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
...
@@ -90,7 +90,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
...
@@ -90,7 +90,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
def
get_orig_t
(
self
,
t
,
num_inference_steps
):
def
get_orig_t
(
self
,
t
,
num_inference_steps
):
if
t
<
0
:
if
t
<
0
:
return
-
1
return
-
1
return
self
.
num_
timesteps
//
num_inference_steps
*
t
return
self
.
timesteps
//
num_inference_steps
*
t
def
get_variance
(
self
,
t
,
num_inference_steps
):
def
get_variance
(
self
,
t
,
num_inference_steps
):
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
...
@@ -105,7 +105,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
...
@@ -105,7 +105,7 @@ class DDIMScheduler(nn.Module, ConfigMixin):
return
variance
return
variance
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
,
eta
,
output_pred_x_0
=
False
):
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
,
eta
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Ideally, read DDIM paper in-detail understanding
...
@@ -152,4 +152,4 @@ class DDIMScheduler(nn.Module, ConfigMixin):
...
@@ -152,4 +152,4 @@ class DDIMScheduler(nn.Module, ConfigMixin):
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
num_
timesteps
return
self
.
timesteps
src/diffusers/schedulers/gaussian_ddpm.py
View file @
2d97544d
...
@@ -44,7 +44,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
...
@@ -44,7 +44,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
variance_type
=
variance_type
,
variance_type
=
variance_type
,
clip_predicted_image
=
clip_predicted_image
,
clip_predicted_image
=
clip_predicted_image
,
)
)
self
.
num_
timesteps
=
int
(
timesteps
)
self
.
timesteps
=
int
(
timesteps
)
self
.
clip_image
=
clip_predicted_image
self
.
clip_image
=
clip_predicted_image
self
.
variance_type
=
variance_type
self
.
variance_type
=
variance_type
...
@@ -107,7 +107,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
...
@@ -107,7 +107,7 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
return
variance
return
variance
def
step
(
self
,
residual
,
image
,
t
,
output_pred_x_0
=
False
):
def
step
(
self
,
residual
,
image
,
t
):
# 1. compute alphas, betas
# 1. compute alphas, betas
alpha_prod_t
=
self
.
get_alpha_prod
(
t
)
alpha_prod_t
=
self
.
get_alpha_prod
(
t
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
t
-
1
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
t
-
1
)
...
@@ -138,4 +138,4 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
...
@@ -138,4 +138,4 @@ class GaussianDDPMScheduler(nn.Module, ConfigMixin):
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
num_
timesteps
return
self
.
timesteps
src/diffusers/schedulers/glide_ddim.py
View file @
2d97544d
...
@@ -32,12 +32,12 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
...
@@ -32,12 +32,12 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
timesteps
=
timesteps
,
timesteps
=
timesteps
,
beta_schedule
=
beta_schedule
,
beta_schedule
=
beta_schedule
,
)
)
self
.
num_
timesteps
=
int
(
timesteps
)
self
.
timesteps
=
int
(
timesteps
)
if
beta_schedule
==
"linear"
:
if
beta_schedule
==
"linear"
:
# Linear schedule from Ho et al, extended to work for any number of
# Linear schedule from Ho et al, extended to work for any number of
# diffusion steps.
# diffusion steps.
scale
=
1000
/
self
.
num_
timesteps
scale
=
1000
/
self
.
timesteps
beta_start
=
scale
*
0.0001
beta_start
=
scale
*
0.0001
beta_end
=
scale
*
0.02
beta_end
=
scale
*
0.02
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
...
@@ -88,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
...
@@ -88,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
def
__len__
(
self
):
def
__len__
(
self
):
return
self
.
num_
timesteps
return
self
.
timesteps
tests/test_modeling_utils.py
View file @
2d97544d
...
@@ -75,16 +75,18 @@ class ModelTesterMixin(unittest.TestCase):
...
@@ -75,16 +75,18 @@ class ModelTesterMixin(unittest.TestCase):
sizes
=
(
32
,
32
)
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
])
time_step
=
torch
.
tensor
([
10
])
.
to
(
torch_device
)
return
(
noise
,
time_step
)
return
(
noise
,
time_step
)
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
model
.
to
(
torch_device
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
model
.
save_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
dummy_input
=
self
.
dummy_input
dummy_input
=
self
.
dummy_input
...
@@ -95,6 +97,7 @@ class ModelTesterMixin(unittest.TestCase):
...
@@ -95,6 +97,7 @@ class ModelTesterMixin(unittest.TestCase):
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
to
(
torch_device
)
image
=
model
(
*
self
.
dummy_input
)
image
=
model
(
*
self
.
dummy_input
)
...
...
tests/test_scheduler.py
View file @
2d97544d
...
@@ -26,8 +26,8 @@ torch.backends.cuda.matmul.allow_tf32 = False
...
@@ -26,8 +26,8 @@ torch.backends.cuda.matmul.allow_tf32 = False
class
SchedulerCommonTest
(
unittest
.
TestCase
):
class
SchedulerCommonTest
(
unittest
.
TestCase
):
scheduler_classes
=
()
scheduler_class
=
None
forward_default_kwargs
=
()
@
property
@
property
def
dummy_image
(
self
):
def
dummy_image
(
self
):
...
@@ -38,42 +38,271 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -38,42 +38,271 @@ class SchedulerCommonTest(unittest.TestCase):
image
=
np
.
random
.
rand
(
batch_size
,
num_channels
,
height
,
width
)
image
=
np
.
random
.
rand
(
batch_size
,
num_channels
,
height
,
width
)
return
image
return
torch
.
tensor
(
image
)
@
property
def
dummy_image_deter
(
self
):
batch_size
=
4
num_channels
=
3
height
=
8
width
=
8
num_elems
=
batch_size
*
num_channels
*
height
*
width
image
=
np
.
arange
(
num_elems
)
image
=
image
.
reshape
(
num_channels
,
height
,
width
,
batch_size
)
image
=
image
/
num_elems
image
=
image
.
transpose
(
3
,
0
,
1
,
2
)
return
torch
.
tensor
(
image
)
def
get_scheduler_config
(
self
):
def
get_scheduler_config
(
self
):
raise
NotImplementedError
raise
NotImplementedError
def
dummy_model
(
self
):
def
dummy_model
(
self
):
def
model
(
image
,
residual
,
t
,
*
args
):
def
model
(
image
,
t
,
*
args
):
return
(
image
+
residual
)
*
t
/
(
t
+
1
)
return
image
*
t
/
(
t
+
1
)
return
model
return
model
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
image
=
self
.
dummy_image
residual
=
0.1
*
image
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
assert
(
output
-
new_output
).
abs
().
sum
()
<
1e-5
,
"Scheduler outputs are not identical"
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
image
=
self
.
dummy_image
residual
=
0.1
*
image
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
time_step
,
**
kwargs
)
assert
(
output
-
new_output
).
abs
().
sum
()
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
image
=
self
.
dummy_image
image
=
self
.
dummy_image
residual
=
0.1
*
image
residual
=
0.1
*
image
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
self
.
scheduler_class
(
scheduler_config
()
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_pretrained
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
self
.
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
new_output
=
new_scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
output
=
scheduler
(
residual
,
image
,
1
)
assert
(
output
-
new_output
).
abs
().
sum
()
<
1e-5
,
"Scheduler outputs are not identical"
new_output
=
new_scheduler
(
residual
,
image
,
1
)
import
ipdb
;
ipdb
.
set_trace
()
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
def
test_step
(
self
)
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
self
.
scheduler_class
(
scheduler_config
()
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
image
=
self
.
dummy_image
image
=
self
.
dummy_image
residual
=
0.1
*
image
residual
=
0.1
*
image
output_0
=
scheduler
(
residual
,
image
,
0
)
output_0
=
scheduler
.
step
(
residual
,
image
,
0
,
**
kwargs
)
output_1
=
scheduler
(
residual
,
image
,
1
)
output_1
=
scheduler
.
step
(
residual
,
image
,
1
,
**
kwargs
)
self
.
assertEqual
(
output_0
.
shape
,
image
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
image
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
class
DDPMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
GaussianDDPMScheduler
,)
def
get_scheduler_config
(
self
,
**
kwargs
):
config
=
{
"timesteps"
:
1000
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"variance_type"
:
"fixed_small"
,
"clip_predicted_image"
:
True
}
config
.
update
(
**
kwargs
)
return
config
def
test_timesteps
(
self
):
for
timesteps
in
[
1
,
5
,
100
,
1000
]:
self
.
check_over_configs
(
timesteps
=
timesteps
)
def
test_betas
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.0001
,
0.001
,
0.01
,
0.1
],
[
0.002
,
0.02
,
0.2
,
2
]):
self
.
check_over_configs
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_schedules
(
self
):
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_variance_type
(
self
):
for
variance
in
[
"fixed_small"
,
"fixed_large"
,
"other"
]:
self
.
check_over_configs
(
variance_type
=
variance
)
def
test_clip_image
(
self
):
for
clip_predicted_image
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_predicted_image
=
clip_predicted_image
)
def
test_time_indices
(
self
):
for
t
in
[
0
,
500
,
999
]:
self
.
check_over_forward
(
time_step
=
t
)
def
test_variance
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
(
scheduler
.
get_variance
(
0
)
-
0.0
).
abs
().
sum
()
<
1e-5
assert
(
scheduler
.
get_variance
(
487
)
-
0.00979
).
abs
().
sum
()
<
1e-5
assert
(
scheduler
.
get_variance
(
999
)
-
0.02
).
abs
().
sum
()
<
1e-5
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
num_trained_timesteps
=
len
(
scheduler
)
model
=
self
.
dummy_model
()
image
=
self
.
dummy_image_deter
for
t
in
reversed
(
range
(
num_trained_timesteps
)):
# 1. predict noise residual
residual
=
model
(
image
,
t
)
# 2. predict previous mean of image x_t-1
pred_prev_image
=
scheduler
.
step
(
residual
,
image
,
t
)
if
t
>
0
:
noise
=
self
.
dummy_image_deter
variance
=
scheduler
.
get_variance
(
t
).
sqrt
()
*
noise
image
=
pred_prev_image
+
variance
result_sum
=
image
.
abs
().
sum
()
result_mean
=
image
.
abs
().
mean
()
assert
result_sum
.
item
()
-
732.9947
<
1e-3
assert
result_mean
.
item
()
-
0.9544
<
1e-3
class
DDIMSchedulerTest
(
SchedulerCommonTest
):
scheduler_classes
=
(
DDIMScheduler
,)
forward_default_kwargs
=
((
"num_inference_steps"
,
50
),
(
"eta"
,
0.0
))
def
get_scheduler_config
(
self
,
**
kwargs
):
config
=
{
"timesteps"
:
1000
,
"beta_start"
:
0.0001
,
"beta_end"
:
0.02
,
"beta_schedule"
:
"linear"
,
"clip_predicted_image"
:
True
}
config
.
update
(
**
kwargs
)
return
config
def
test_timesteps
(
self
):
for
timesteps
in
[
1
,
5
,
100
,
1000
]:
self
.
check_over_configs
(
timesteps
=
timesteps
)
def
test_betas
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.0001
,
0.001
,
0.01
,
0.1
],
[
0.002
,
0.02
,
0.2
,
2
]):
self
.
check_over_configs
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_schedules
(
self
):
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_clip_image
(
self
):
for
clip_predicted_image
in
[
True
,
False
]:
self
.
check_over_configs
(
clip_predicted_image
=
clip_predicted_image
)
def
test_time_indices
(
self
):
for
t
in
[
1
,
10
,
49
]:
self
.
check_over_forward
(
time_step
=
t
)
def
test_inference_steps
(
self
):
for
t
,
num_inference_steps
in
zip
([
1
,
10
,
50
],
[
10
,
50
,
500
]):
self
.
check_over_forward
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
def
test_eta
(
self
):
for
t
,
eta
in
zip
([
1
,
10
,
49
],
[
0.0
,
0.5
,
1.0
]):
self
.
check_over_forward
(
time_step
=
t
,
eta
=
eta
)
def
test_variance
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
assert
(
scheduler
.
get_variance
(
0
,
50
)
-
0.0
).
abs
().
sum
()
<
1e-5
assert
(
scheduler
.
get_variance
(
21
,
50
)
-
0.14771
).
abs
().
sum
()
<
1e-5
assert
(
scheduler
.
get_variance
(
49
,
50
)
-
0.32460
).
abs
().
sum
()
<
1e-5
assert
(
scheduler
.
get_variance
(
0
,
1000
)
-
0.0
).
abs
().
sum
()
<
1e-5
assert
(
scheduler
.
get_variance
(
487
,
1000
)
-
0.00979
).
abs
().
sum
()
<
1e-5
assert
(
scheduler
.
get_variance
(
999
,
1000
)
-
0.02
).
abs
().
sum
()
<
1e-5
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
num_inference_steps
,
eta
=
10
,
0.1
num_trained_timesteps
=
len
(
scheduler
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
model
=
self
.
dummy_model
()
image
=
self
.
dummy_image_deter
for
t
in
reversed
(
range
(
num_inference_steps
)):
residual
=
model
(
image
,
inference_step_times
[
t
])
pred_prev_image
=
scheduler
.
step
(
residual
,
image
,
t
,
num_inference_steps
,
eta
)
variance
=
0
if
eta
>
0
:
noise
=
self
.
dummy_image_deter
variance
=
scheduler
.
get_variance
(
t
,
num_inference_steps
).
sqrt
()
*
eta
*
noise
image
=
pred_prev_image
+
variance
result_sum
=
image
.
abs
().
sum
()
result_mean
=
image
.
abs
().
mean
()
assert
result_sum
.
item
()
-
270.6214
<
1e-3
assert
result_mean
.
item
()
-
0.3524
<
1e-3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment