OpenDAS / diffusers — commit f448360b (unverified)

Finish scheduler API (#91)

* finish
* up

Authored Jul 15, 2022 by Patrick von Platen; committed by GitHub, Jul 15, 2022.
Parent: 97e1e3ba

Showing 14 changed files with 233 additions and 188 deletions (+233 −188)
src/diffusers/models/unet_glide.py                                                    +4   −4
src/diffusers/models/unet_sde_score_estimation.py                                     +2   −2
src/diffusers/models/unet_unconditional.py                                            +5   −3
src/diffusers/pipelines/ddim/pipeline_ddim.py                                         +11  −29
src/diffusers/pipelines/ddpm/pipeline_ddpm.py                                         +6   −6
src/diffusers/pipelines/glide/pipeline_glide.py                                       +11  −13
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py                 +6   −6
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py   +12  −33
src/diffusers/pipelines/pndm/pipeline_pndm.py                                         +7   −7
src/diffusers/schedulers/scheduling_ddim.py                                           +43  −15
src/diffusers/schedulers/scheduling_ddpm.py                                           +11  −2
src/diffusers/schedulers/scheduling_pndm.py                                           +20  −4
tests/test_modeling_utils.py                                                          +20  −21
tests/test_scheduler.py                                                               +75  −43
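Taken together, the 14 diffs below converge on a single scheduler interface: pipelines register a `scheduler` (previously `noise_scheduler`), schedulers that subsample the training schedule expose `set_timesteps(num_inference_steps)` and a `timesteps` attribute, every `step*` method takes `(residual, timestep, sample, ...)` and returns a dict keyed by `"prev_sample"`, and model `forward` methods call their second argument `timestep`. As orientation, a hedged sketch of the resulting call pattern (the stand-in model and the tensor shapes are illustrative, not code from the commit):

import torch

from diffusers import DDIMScheduler


def fake_unet(sample, timestep):
    # trivial stand-in for a real UNet's noise prediction
    return {"sample": 0.1 * sample}


scheduler = DDIMScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)   # new: timestep subsampling lives on the scheduler

sample = torch.randn(1, 3, 32, 32)
for t in scheduler.timesteps:
    residual = fake_unet(sample, t)["sample"]
    # new ordering and return type: (residual, timestep, sample, ...) -> {"prev_sample": ...}
    sample = scheduler.step(residual, t, sample, eta=0.0)["prev_sample"]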
src/diffusers/models/unet_glide.py  View file @ f448360b

@@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
         self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
 
-    def forward(self, sample, step_value, transformer_out=None):
-        timesteps = step_value
+    def forward(self, sample, timestep, transformer_out=None):
+        timesteps = timestep
         x = sample
         hs = []
         emb = self.time_embed(
@@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
             resblock_updown=resblock_updown,
         )
 
-    def forward(self, sample, step_value, low_res=None):
-        timesteps = step_value
+    def forward(self, sample, timestep, low_res=None):
+        timesteps = timestep
         x = sample
         _, _, new_height, new_width = x.shape
         upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
src/diffusers/models/unet_sde_score_estimation.py  View file @ f448360b

@@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
         self.all_modules = nn.ModuleList(modules)
 
-    def forward(self, sample, step_value, sigmas=None):
-        timesteps = step_value
+    def forward(self, sample, timestep, sigmas=None):
+        timesteps = timestep
         x = sample
         # timestep/noise_level embedding; only for continuous training
         modules = self.all_modules
src/diffusers/models/unet_unconditional.py  View file @ f448360b

@@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
     # ======================================
     def forward(
-        self, sample: torch.FloatTensor, step_value: Union[torch.Tensor, float, int]
+        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
     ) -> Dict[str, torch.FloatTensor]:
         # TODO(PVP) - to delete later at release
         # IMPORTANT: NOT RELEVANT WHEN REVIEWING API
@@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
             self.set_weights()
         # ======================================
 
-        # 1. time step embeddings
-        timesteps = step_value
+        # 1. time step embeddings -> make correct tensor
+        timesteps = timestep
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
 
         t_emb = get_timestep_embedding(
             timesteps,
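The three model diffs above amount to renaming the second `forward` argument from `step_value` to `timestep` (plus, in `UNetUnconditionalModel`, promoting 0-d timestep tensors to a 1-element batch). A hedged caller sketch against the renamed argument, assuming the `fusing/ddpm-cifar10` checkpoint used by the tests in this commit and that the model exposes `in_channels` and `image_size` the way the unconditional latent-diffusion pipeline relies on:

import torch

from diffusers import UNetUnconditionalModel

# Assumed checkpoint; the updated tests below load the same id with ddpm=True.
model = UNetUnconditionalModel.from_pretrained("fusing/ddpm-cifar10", ddpm=True)

sample = torch.randn(1, model.in_channels, model.image_size, model.image_size)
timestep = torch.tensor([10])  # the keyword is now `timestep`, not `step_value`

with torch.no_grad():
    out = model(sample, timestep=timestep)
if isinstance(out, dict):  # the pipelines in this diff unwrap the dict the same way
    out = out["sample"]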
src/diffusers/pipelines/ddim/pipeline_ddim.py  View file @ f448360b

@@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline
 class DDIMPipeline(DiffusionPipeline):
-    def __init__(self, unet, noise_scheduler):
+    def __init__(self, unet, scheduler):
         super().__init__()
-        noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(unet=unet, scheduler=scheduler)
 
     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        num_trained_timesteps = self.noise_scheduler.config.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
         self.unet.to(torch_device)
 
         # Sample gaussian noise to begin loop
@@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
-        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
+        # set step values
+        self.scheduler.set_timesteps(num_inference_steps)
+
+        for t in tqdm.tqdm(self.scheduler.timesteps):
             # 1. predict noise residual
             with torch.no_grad():
-                residual = self.unet(image, inference_step_times[t])
+                residual = self.unet(image, t)
 
             if isinstance(residual, dict):
                 residual = residual["sample"]
 
-            # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
-
-            # 3. optionally sample variance
-            variance = 0
-            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-
-            # 4. set current image to prev_image: x_t -> x_t-1
-            image = pred_prev_image + variance
+            # 2. predict previous mean of image x_t-1 and add variance depending on eta
+            # do x_t -> x_t-1
+            image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
 
-        return image
+        return {"sample": image}
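After this change the DDIM pipeline no longer spells out formulas (12) and (16) itself: timestep subsampling moves into `scheduler.set_timesteps`, and the η-dependent variance is drawn inside `scheduler.step`. A minimal standalone sketch of the equivalent loop (the helper name `ddim_sample` and its arguments are illustrative, not part of the diff):

import torch
import tqdm


def ddim_sample(unet, scheduler, image, eta=0.0, num_inference_steps=50):
    # set step values once; scheduler.timesteps then holds the subsampled schedule
    scheduler.set_timesteps(num_inference_steps)

    for t in tqdm.tqdm(scheduler.timesteps):
        # 1. predict the noise residual
        with torch.no_grad():
            residual = unet(image, t)
        if isinstance(residual, dict):
            residual = residual["sample"]

        # 2. x_t -> x_t-1; variance sampling for eta > 0 now happens inside step()
        image = scheduler.step(residual, t, image, eta)["prev_sample"]

    return image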
src/diffusers/pipelines/ddpm/pipeline_ddpm.py  View file @ f448360b

@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
 class DDPMPipeline(DiffusionPipeline):
-    def __init__(self, unet, noise_scheduler):
+    def __init__(self, unet, scheduler):
         super().__init__()
-        noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(unet=unet, scheduler=scheduler)
 
     def __call__(self, batch_size=1, generator=None, torch_device=None):
         if torch_device is None:
@@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        num_prediction_steps = len(self.noise_scheduler)
+        num_prediction_steps = len(self.scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise residual
             with torch.no_grad():
@@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline):
                 residual = residual["sample"]
 
             # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.noise_scheduler.step(residual, image, t)
+            pred_prev_image = self.scheduler.step(residual, t, image)["prev_sample"]
 
             # 3. optionally sample variance
             variance = 0
             if t > 0:
                 noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.noise_scheduler.get_variance(t).sqrt() * noise
+                variance = self.scheduler.get_variance(t).sqrt() * noise
 
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
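Callers only have to switch the constructor keyword (and pipeline attribute) from `noise_scheduler` to `scheduler`. This sketch mirrors the updated test in tests/test_modeling_utils.py further down and assumes the same `fusing/ddpm-cifar10` checkpoint and package-level exports:

import torch

from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel

model_id = "fusing/ddpm-cifar10"

unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
scheduler = DDPMScheduler.from_config(model_id)
scheduler = scheduler.set_format("pt")

# the keyword was `noise_scheduler=...` before this commit
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)

generator = torch.manual_seed(0)
image = ddpm(generator=generator)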
src/diffusers/pipelines/glide/pipeline_glide.py  View file @ f448360b

@@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline):
     def __init__(
         self,
         text_unet: GlideTextToImageUNetModel,
-        text_noise_scheduler: DDPMScheduler,
+        text_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GlideSuperResUNetModel,
-        upscale_noise_scheduler: DDIMScheduler,
+        upscale_scheduler: DDIMScheduler,
     ):
         super().__init__()
         self.register_modules(
             text_unet=text_unet,
-            text_noise_scheduler=text_noise_scheduler,
+            text_scheduler=text_scheduler,
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             upscale_unet=upscale_unet,
-            upscale_noise_scheduler=upscale_noise_scheduler,
+            upscale_scheduler=upscale_scheduler,
         )
 
     @torch.no_grad()
@@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline):
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
 
         # 3. Run the text2image generation step
-        num_prediction_steps = len(self.text_noise_scheduler)
+        num_prediction_steps = len(self.text_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             with torch.no_grad():
                 time_input = torch.tensor([t] * image.shape[0], device=torch_device)
                 model_output = text_model_fn(image, time_input, transformer_out)
                 noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
 
-                min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
-                max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+                min_log = self.text_scheduler.get_variance(t, "fixed_small_log")
+                max_log = self.text_scheduler.get_variance(t, "fixed_large_log")
                 # The model_var_values is [-1, 1] for [min_var, max_var].
                 frac = (model_var_values + 1) / 2
                 model_log_variance = frac * max_log + (1 - frac) * min_log
 
-                pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
+                pred_prev_image = self.text_scheduler.step(noise_residual, image, t)
                 noise = torch.randn(image.shape, generator=generator).to(torch_device)
                 variance = torch.exp(0.5 * model_log_variance) * noise
@@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline):
         ).to(torch_device)
         image = image * upsample_temp
 
-        num_trained_timesteps = self.upscale_noise_scheduler.timesteps
+        num_trained_timesteps = self.upscale_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
 
         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
@@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline):
                 noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
 
             # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.upscale_noise_scheduler.step(
+            pred_prev_image = self.upscale_scheduler.step(
                 noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
             )
@@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline):
             variance = 0
             if eta > 0:
                 noise = torch.randn(image.shape, generator=generator).to(torch_device)
-                variance = (
-                    self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
-                )
+                variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
 
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py  View file @ f448360b

@@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel):
 class LatentDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
+    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
         super().__init__()
-        noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
     def __call__(
@@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)
 
-        num_trained_timesteps = self.noise_scheduler.config.timesteps
+        num_trained_timesteps = self.scheduler.config.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
 
         image = torch.randn(
@@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline):
             pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
 
             # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
+            pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
 
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
                 noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
+                variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
 
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py  View file @ f448360b

@@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline
 class LatentDiffusionUncondPipeline(DiffusionPipeline):
-    def __init__(self, vqvae, unet, noise_scheduler):
+    def __init__(self, vqvae, unet, scheduler):
         super().__init__()
-        noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(vqvae=vqvae, unet=unet, noise_scheduler=noise_scheduler)
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
     def __call__(
@@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
 
-        num_trained_timesteps = self.noise_scheduler.config.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
         ).to(torch_device)
 
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
-        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # 1. predict noise residual
-            timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
-            pred_noise_t = self.unet(image, timesteps)
-            if isinstance(pred_noise_t, dict):
-                pred_noise_t = pred_noise_t["sample"]
+        self.scheduler.set_timesteps(num_inference_steps)
 
-            # 2. predict previous mean of image x_t-1
-            pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
+        for t in tqdm.tqdm(self.scheduler.timesteps):
+            residual = self.unet(image, t)
 
-            # 3. optionally sample variance
-            variance = 0
-            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
-                variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
+            if isinstance(residual, dict):
+                residual = residual["sample"]
 
-            # 4. set current image to prev_image: x_t -> x_t-1
-            image = pred_prev_image + variance
+            # 2. predict previous mean of image x_t-1 and add variance depending on eta
+            # do x_t -> x_t-1
+            image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
 
         # decode image with vae
         image = self.vqvae.decode(image)
 
-        return image
+        return {"sample": image}
src/diffusers/pipelines/pndm/pipeline_pndm.py  View file @ f448360b

@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
 class PNDMPipeline(DiffusionPipeline):
-    def __init__(self, unet, noise_scheduler):
+    def __init__(self, unet, scheduler):
         super().__init__()
-        noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(unet=unet, scheduler=scheduler)
 
     def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
         # For more information on the sampling method you can take a look at Algorithm 2 of
@@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
+        prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
         for t in tqdm.tqdm(range(len(prk_time_steps))):
             t_orig = prk_time_steps[t]
             residual = self.unet(image, t_orig)
@@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline):
             if isinstance(residual, dict):
                 residual = residual["sample"]
 
-            image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
+            image = self.scheduler.step_prk(residual, t, image, num_inference_steps)["prev_sample"]
 
-        timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
+        timesteps = self.scheduler.get_time_steps(num_inference_steps)
         for t in tqdm.tqdm(range(len(timesteps))):
             t_orig = timesteps[t]
             residual = self.unet(image, t_orig)
@@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline):
             if isinstance(residual, dict):
                 residual = residual["sample"]
 
-            image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
+            image = self.scheduler.step_plms(residual, t, image, num_inference_steps)["prev_sample"]
 
         return image
src/diffusers/schedulers/scheduling_ddim.py  View file @ f448360b

@@ -16,8 +16,10 @@
 # and https://github.com/hojonathanho/diffusion
 
 import math
+from typing import Union
 
 import numpy as np
+import torch
 
 from ..configuration_utils import ConfigMixin
 from .scheduling_utils import SchedulerMixin
@@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
         self.one = np.array(1.0)
 
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
+
+        self.tensor_format = tensor_format
         self.set_format(tensor_format=tensor_format)
 
-    def get_variance(self, t, num_inference_steps):
-        orig_t = self.config.timesteps // num_inference_steps * t
-        orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
-
-        alpha_prod_t = self.alphas_cumprod[orig_t]
-        alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
-    def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False):
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(0, self.config.timesteps, self.config.timesteps // self.num_inference_steps)[::-1].copy()
+        self.set_format(tensor_format=self.tensor_format)
+
+    def step(
+        self,
+        residual: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        eta,
+        use_clipped_residual=False,
+        generator=None,
+    ):
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
@@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # - pred_sample_direction -> "direction pointingc to x_t"
         # - pred_prev_sample -> "x_t-1"
 
-        # 1. get actual t and t-1
-        orig_t = self.config.timesteps // num_inference_steps * t
-        orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
+        # 1. get previous step value (=t-1)
+        prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[orig_t]
-        alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
         beta_prod_t = 1 - alpha_prod_t
 
         # 3. compute predicted original sample from predicted noise also called
@@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self.get_variance(t, num_inference_steps)
+        variance = self._get_variance(timestep, prev_timestep)
         std_dev_t = eta * variance ** (0.5)
 
         if use_clipped_residual:
@@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
 
         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+        if eta > 0:
+            device = residual.device if torch.is_tensor(residual) else "cpu"
+            noise = torch.randn(residual.shape, generator=generator).to(device)
+            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+
+            if not torch.is_tensor(residual):
+                variance = variance.numpy()
+
+            prev_sample = prev_sample + variance
 
-        return pred_prev_sample
+        return {"prev_sample": prev_sample}
 
     def add_noise(self, original_samples, noise, timesteps):
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
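With η > 0 the DDIM scheduler now draws the extra noise itself; an optional `generator` keeps that draw reproducible, and the variance comes from the now-private `_get_variance(timestep, prev_timestep)`, which works on trained-timestep values rather than inference-step indices. A small sketch of a single stochastic step, assuming default config values (the tensors are placeholders, not code from the commit):

import torch

from diffusers import DDIMScheduler

scheduler = DDIMScheduler(tensor_format="pt")
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 3, 32, 32)
residual = 0.1 * sample                 # stand-in for a model's noise prediction
t = int(scheduler.timesteps[0])         # a trained-timestep value, not a loop index

generator = torch.manual_seed(0)        # new optional argument: reproducible eta-noise
prev_sample = scheduler.step(residual, t, sample, eta=1.0, generator=generator)["prev_sample"]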
src/diffusers/schedulers/scheduling_ddpm.py  View file @ f448360b

@@ -15,8 +15,10 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
+from typing import Union
 
 import numpy as np
+import torch
 
 from ..configuration_utils import ConfigMixin
 from .scheduling_utils import SchedulerMixin
@@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
-    def step(self, residual, sample, t, predict_epsilon=True):
+    def step(
+        self,
+        residual: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        predict_epsilon=True,
+    ):
+        t = timestep
         # 1. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
@@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
         pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
 
-        return pred_prev_sample
+        return {"prev_sample": pred_prev_sample}
 
     def add_noise(self, original_samples, noise, timesteps):
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
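DDPM's `step` keeps its math but adopts the shared `(residual, timestep, sample)` ordering and the dict return; as the updated DDPM pipeline above shows, the caller still samples the variance term itself for t > 0. A small sketch of one denoising step, with placeholder tensors standing in for real model output:

import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(tensor_format="pt")

t = 10
sample = torch.randn(1, 3, 32, 32)
residual = 0.1 * sample  # stand-in for the predicted noise

# new ordering: (residual, timestep, sample); returns a dict instead of a bare tensor
prev_sample = scheduler.step(residual, t, sample)["prev_sample"]

if t > 0:  # the pipeline still adds the variance term itself
    noise = torch.randn(sample.shape)
    prev_sample = prev_sample + scheduler.get_variance(t).sqrt() * noise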
src/diffusers/schedulers/scheduling_pndm.py  View file @ f448360b

@@ -15,8 +15,10 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
+from typing import Union
 
 import numpy as np
+import torch
 
 from ..configuration_utils import ConfigMixin
 from .scheduling_utils import SchedulerMixin
@@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             raise ValueError(f"mode {self.mode} does not exist.")
 
-    def step_prk(self, residual, sample, t, num_inference_steps):
+    def step_prk(
+        self,
+        residual: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        num_inference_steps,
+    ):
+        t = timestep
         prk_time_steps = self.get_prk_time_steps(num_inference_steps)
 
         t_orig = prk_time_steps[t // 4 * 4]
@@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         # cur_sample should not be `None`
         cur_sample = self.cur_sample if self.cur_sample is not None else sample
 
-        return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)
+        return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)}
 
-    def step_plms(self, residual, sample, t, num_inference_steps):
+    def step_plms(
+        self,
+        residual: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        num_inference_steps,
+    ):
+        t = timestep
         if len(self.ets) < 3:
             raise ValueError(
                 f"{self.__class__} can only be run AFTER scheduler has been run "
@@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
 
-        return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
+        return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, residual)}
 
     def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
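PNDM follows the same pattern: `step_prk` and `step_plms` now take `(residual, timestep, sample, num_inference_steps)` and wrap their result as `{"prev_sample": ...}`. A hedged sketch of the Runge-Kutta warm-up plus the linear multi-step phase, mirroring the updated PNDM pipeline earlier in this diff (the `unet`, `scheduler`, and `image` arguments are placeholders):

import torch
import tqdm


def pndm_sample(unet, scheduler, image, num_inference_steps=50):
    # Runge-Kutta warm-up steps
    prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
    for t in tqdm.tqdm(range(len(prk_time_steps))):
        t_orig = prk_time_steps[t]
        residual = unet(image, t_orig)
        if isinstance(residual, dict):
            residual = residual["sample"]
        image = scheduler.step_prk(residual, t, image, num_inference_steps)["prev_sample"]

    # linear multi-step phase
    timesteps = scheduler.get_time_steps(num_inference_steps)
    for t in tqdm.tqdm(range(len(timesteps))):
        t_orig = timesteps[t]
        residual = unet(image, t_orig)
        if isinstance(residual, dict):
            residual = residual["sample"]
        image = scheduler.step_plms(residual, t, image, num_inference_steps)["prev_sample"]

    return image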
tests/test_modeling_utils.py  View file @ f448360b

@@ -165,7 +165,7 @@ class ModelTesterMixin:
             # signature.parameters is an OrderedDict => so arg_names order is deterministic
             arg_names = [*signature.parameters.keys()]
 
-            expected_arg_names = ["sample", "step_value"]
+            expected_arg_names = ["sample", "timestep"]
             self.assertListEqual(arg_names[:2], expected_arg_names)
 
     def test_model_from_config(self):
@@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
 
-        return {"sample": noise, "step_value": time_step}
+        return {"sample": noise, "timestep": time_step}
 
     @property
     def input_shape(self):
@@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
         low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
         time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
 
-        return {"sample": noise, "step_value": time_step, "low_res": low_res}
+        return {"sample": noise, "timestep": time_step, "low_res": low_res}
 
     @property
     def input_shape(self):
@@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
         emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
         time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
 
-        return {"sample": noise, "step_value": time_step, "transformer_out": emb}
+        return {"sample": noise, "timestep": time_step, "transformer_out": emb}
 
     @property
     def input_shape(self):
@@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
 
-        return {"sample": noise, "step_value": time_step}
+        return {"sample": noise, "timestep": time_step}
 
     @property
    def input_shape(self):
@@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor(batch_size * [10]).to(torch_device)
 
-        return {"sample": noise, "step_value": time_step}
+        return {"sample": noise, "timestep": time_step}
 
     @property
     def input_shape(self):
@@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline.from_pretrained(model_path)
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
 
-        ddpm.noise_scheduler.num_timesteps = 10
-        ddpm_from_hub.noise_scheduler.num_timesteps = 10
+        ddpm.scheduler.num_timesteps = 10
+        ddpm_from_hub.scheduler.num_timesteps = 10
 
         generator = torch.manual_seed(0)
@@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "fusing/ddpm-cifar10"
 
         unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
-        noise_scheduler = DDPMScheduler.from_config(model_id)
-        noise_scheduler = noise_scheduler.set_format("pt")
+        scheduler = DDPMScheduler.from_config(model_id)
+        scheduler = scheduler.set_format("pt")
 
-        ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler)
+        ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator)
@@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "fusing/ddpm-lsun-bedroom-ema"
 
         unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
-        noise_scheduler = DDIMScheduler.from_config(model_id)
-        noise_scheduler = noise_scheduler.set_format("pt")
+        scheduler = DDIMScheduler.from_config(model_id)
 
-        ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
+        ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)
+        image = ddpm(generator=generator)["sample"]
         image_slice = image[0, -1, -3:, -3:].cpu()
@@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "fusing/ddpm-cifar10"
 
         unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
-        noise_scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler(tensor_format="pt")
 
-        ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
+        ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddim(generator=generator, eta=0.0)
+        image = ddim(generator=generator, eta=0.0)["sample"]
         image_slice = image[0, -1, -3:, -3:].cpu()
@@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "fusing/ddpm-cifar10"
 
         unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
-        noise_scheduler = PNDMScheduler(tensor_format="pt")
+        scheduler = PNDMScheduler(tensor_format="pt")
 
-        pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
+        pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         generator = torch.manual_seed(0)
         image = pndm(generator=generator)
@@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
         generator = torch.manual_seed(0)
 
-        image = ldm(generator=generator, num_inference_steps=5)
+        image = ldm(generator=generator, num_inference_steps=5)["sample"]
         image_slice = image[0, -1, -3:, -3:].cpu()
tests/test_scheduler.py
View file @
f448360b
...
@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
...
@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
...
@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
kwargs
.
update
(
forward_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
...
@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
time_step
,
**
kwargs
)
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
torch
.
manual_seed
(
0
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
torch
.
manual_seed
(
0
)
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
...
@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
output
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
new_output
=
new_scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
scheduler
.
set_timesteps
(
num_inference_steps
)
new_scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_step_shape
(
self
):
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase):
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
output_0
=
scheduler
.
step
(
residual
,
sample
,
0
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
output_1
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output_0
=
scheduler
.
step
(
residual
,
0
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
...
@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase):
def
test_pytorch_equal_numpy
(
self
):
def
test_pytorch_equal_numpy
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
...
@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
output
=
scheduler
.
step
(
residual
,
sample
,
1
,
**
kwargs
)
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
sample_pt
,
1
,
**
kwargs
)
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler_pt
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step
(
residual
,
1
,
sample
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step
(
residual_pt
,
1
,
sample_pt
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
...
@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
...
@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual
=
model
(
sample
,
t
)
residual
=
model
(
sample
,
t
)
# 2. predict previous mean of sample x_t-1
# 2. predict previous mean of sample x_t-1
pred_prev_sample
=
scheduler
.
step
(
residual
,
sample
,
t
)
pred_prev_sample
=
scheduler
.
step
(
residual
,
t
,
sample
)[
"prev_sample"
]
if
t
>
0
:
if
t
>
0
:
noise
=
self
.
dummy_sample_deter
noise
=
self
.
dummy_sample_deter
...
@@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
class DDIMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DDIMScheduler,)
-    forward_default_kwargs = (("num_inference_steps", 50), ("eta", 0.0))
+    forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))

    def get_scheduler_config(self, **kwargs):
        config = {
...
@@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
        return config

    def test_timesteps(self):
-        for timesteps in [1, 5, 100, 1000]:
+        for timesteps in [100, 500, 1000]:
            self.check_over_configs(timesteps=timesteps)

    def test_betas(self):
...
@@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
-            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
+            self.check_over_forward(num_inference_steps=num_inference_steps)

    def test_eta(self):
        for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
...
@@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest):
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

-        assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5
-        assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5
-        assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5
-        assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5
-        assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5
-        assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5
+        assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
+        assert np.sum(np.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
+        assert np.sum(np.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
+        assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
+        assert np.sum(np.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
+        assert np.sum(np.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

-        num_inference_steps, eta = 10, 0.1
-        num_trained_timesteps = len(scheduler)
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
+        num_inference_steps, eta = 10, 0.0

        model = self.dummy_model()
        sample = self.dummy_sample_deter

-        for t in reversed(range(num_inference_steps)):
-            residual = model(sample, inference_step_times[t])
-            pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta)
-            variance = 0
-            if eta > 0:
-                noise = self.dummy_sample_deter
-                variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise
-            sample = pred_prev_sample + variance
+        scheduler.set_timesteps(num_inference_steps)
+        for t in scheduler.timesteps:
+            residual = model(sample, t)
+            sample = scheduler.step(residual, t, sample, eta)["prev_sample"]

        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))

-        assert abs(result_sum.item() - 270.6214) < 1e-2
-        assert abs(result_mean.item() - 0.3524) < 1e-3
+        assert abs(result_sum.item() - 172.0067) < 1e-2
+        assert abs(result_mean.item() - 0.223967) < 1e-3
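Two things change for DDIM here: the variance helper becomes the private _get_variance(timestep, prev_timestep), and the full-loop test asks the scheduler for its timesteps via set_timesteps instead of computing a stride by hand, forwarding eta straight into step. A sketch of the resulting inference loop; the toy model, the sample shape, and the default DDIMScheduler config are assumptions:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(tensor_format="pt")
num_inference_steps, eta = 10, 0.0  # eta=0.0 keeps the update deterministic

def model(sample, t):
    return 0.1 * sample  # toy stand-in for a trained UNet

sample = torch.randn(1, 3, 32, 32)
scheduler.set_timesteps(num_inference_steps)
for t in scheduler.timesteps:
    residual = model(sample, t)
    # eta is forwarded to step(); the variance is computed internally from
    # _get_variance(timestep, prev_timestep)
    sample = scheduler.step(residual, t, sample, eta)["prev_sample"]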
class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
            new_scheduler.ets = dummy_past_residuals[:]
            new_scheduler.set_plms_mode()

-            output = scheduler.step(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
+            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...
@@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
            new_scheduler.ets = dummy_past_residuals[:]
            new_scheduler.set_plms_mode()

-            output = scheduler.step(residual, sample, time_step, **kwargs)
-            new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
+            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]

            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
...
@@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
        scheduler.set_plms_mode()
-        scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50)
+        scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
...
@@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
            t_orig = prk_time_steps[t]
            residual = model(sample, t_orig)
-            sample = scheduler.step_prk(residual, sample, t, num_inference_steps)
+            sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]

        timesteps = scheduler.get_time_steps(num_inference_steps)
        for t in range(len(timesteps)):
            t_orig = timesteps[t]
            residual = model(sample, t_orig)
-            sample = scheduler.step_plms(residual, sample, t, num_inference_steps)
+            sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]

        result_sum = np.sum(np.abs(sample))
        result_mean = np.mean(np.abs(sample))
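The PNDM loop keeps its two phases, a Runge-Kutta warm-up followed by linear multi-step updates, but both step_prk and step_plms now follow the (model_output, timestep, sample) order and return a dict. A sketch of the full loop; the get_prk_time_steps accessor name, the toy model, and the sample shape are assumptions made for the sketch:

import numpy as np
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(tensor_format="np")
num_inference_steps = 10

def model(sample, t):
    return 0.1 * sample  # toy stand-in for a trained UNet

sample = np.random.randn(1, 3, 8, 8).astype(np.float32)

# Runge-Kutta warm-up phase (accessor name assumed)
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
for t in range(len(prk_time_steps)):
    residual = model(sample, prk_time_steps[t])
    sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]

# linear multi-step phase
timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)):
    residual = model(sample, timesteps[t])
    sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]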
...