Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
13e37cab
Commit
13e37cab
authored
Jul 20, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
760dcb1f
889aa600
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
377 additions
and
307 deletions
+377
-307
examples/README.md
examples/README.md
+14
-16
examples/train_unconditional.py
examples/train_unconditional.py
+60
-66
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+1
-1
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+15
-0
src/diffusers/pipelines/ddim/pipeline_ddim.py
src/diffusers/pipelines/ddim/pipeline_ddim.py
+6
-1
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+6
-1
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
...s/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+44
-28
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
...tent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+14
-18
src/diffusers/pipelines/pndm/pipeline_pndm.py
src/diffusers/pipelines/pndm/pipeline_pndm.py
+13
-11
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
...diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+11
-6
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+18
-39
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+72
-68
tests/test_scheduler.py
tests/test_scheduler.py
+103
-52
No files found.
examples/README.md
View file @
13e37cab
...
@@ -5,18 +5,17 @@
...
@@ -5,18 +5,17 @@
The command to train a DDPM UNet model on the Oxford Flowers dataset:
The command to train a DDPM UNet model on the Oxford Flowers dataset:
```
bash
```
bash
python
-m
torch.distributed.launch
\
accelerate launch train_unconditional.py
\
--nproc_per_node
4
\
train_unconditional.py
\
--dataset
=
"huggan/flowers-102-categories"
\
--dataset
=
"huggan/flowers-102-categories"
\
--resolution
=
64
\
--resolution
=
64
\
--output_dir
=
"flowers-
ddpm
"
\
--output_dir
=
"
ddpm-ema-
flowers-
64
"
\
--batch_size
=
16
\
--
train_
batch_size
=
16
\
--num_epochs
=
100
\
--num_epochs
=
100
\
--gradient_accumulation_steps
=
1
\
--gradient_accumulation_steps
=
1
\
--lr
=
1e-4
\
--learning_rate
=
1e-4
\
--warmup_steps
=
500
\
--lr_warmup_steps
=
500
\
--mixed_precision
=
no
--mixed_precision
=
no
\
--push_to_hub
```
```
A full training run takes 2 hours on 4xV100 GPUs.
A full training run takes 2 hours on 4xV100 GPUs.
...
@@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs.
...
@@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs.
The command to train a DDPM UNet model on the Pokemon dataset:
The command to train a DDPM UNet model on the Pokemon dataset:
```
bash
```
bash
python
-m
torch.distributed.launch
\
accelerate launch train_unconditional.py
\
--nproc_per_node
4
\
train_unconditional.py
\
--dataset
=
"huggan/pokemon"
\
--dataset
=
"huggan/pokemon"
\
--resolution
=
64
\
--resolution
=
64
\
--output_dir
=
"pokemon-
ddpm
"
\
--output_dir
=
"
ddpm-ema-
pokemon-
64
"
\
--batch_size
=
16
\
--
train_
batch_size
=
16
\
--num_epochs
=
100
\
--num_epochs
=
100
\
--gradient_accumulation_steps
=
1
\
--gradient_accumulation_steps
=
1
\
--lr
=
1e-4
\
--learning_rate
=
1e-4
\
--warmup_steps
=
500
\
--lr_warmup_steps
=
500
\
--mixed_precision
=
no
--mixed_precision
=
no
\
--push_to_hub
```
```
A full training run takes 2 hours on 4xV100 GPUs.
A full training run takes 2 hours on 4xV100 GPUs.
...
...
examples/train_unconditional.py
View file @
13e37cab
...
@@ -4,10 +4,10 @@ import os
...
@@ -4,10 +4,10 @@ import os
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
accelerate
import
Accelerator
,
DistributedDataParallelKwargs
from
accelerate
import
Accelerator
from
accelerate.logging
import
get_logger
from
accelerate.logging
import
get_logger
from
datasets
import
load_dataset
from
datasets
import
load_dataset
from
diffusers
import
DD
I
MPipeline
,
DD
I
MScheduler
,
UNetModel
from
diffusers
import
DD
P
MPipeline
,
DD
P
MScheduler
,
UNet
Unconditional
Model
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.hub_utils
import
init_git_repo
,
push_to_hub
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
...
@@ -27,25 +27,37 @@ logger = get_logger(__name__)
...
@@ -27,25 +27,37 @@ logger = get_logger(__name__)
def
main
(
args
):
def
main
(
args
):
ddp_unused_params
=
DistributedDataParallelKwargs
(
find_unused_parameters
=
True
)
logging_dir
=
os
.
path
.
join
(
args
.
output_dir
,
args
.
logging_dir
)
logging_dir
=
os
.
path
.
join
(
args
.
output_dir
,
args
.
logging_dir
)
accelerator
=
Accelerator
(
accelerator
=
Accelerator
(
mixed_precision
=
args
.
mixed_precision
,
mixed_precision
=
args
.
mixed_precision
,
log_with
=
"tensorboard"
,
log_with
=
"tensorboard"
,
logging_dir
=
logging_dir
,
logging_dir
=
logging_dir
,
kwargs_handlers
=
[
ddp_unused_params
],
)
)
model
=
UNetModel
(
model
=
UNetUnconditionalModel
(
attn_resolutions
=
(
16
,),
image_size
=
args
.
resolution
,
ch
=
128
,
in_channels
=
3
,
ch_mult
=
(
1
,
2
,
4
,
8
),
out_channels
=
3
,
dropout
=
0.0
,
num_res_blocks
=
2
,
num_res_blocks
=
2
,
resamp_with_conv
=
True
,
block_channels
=
(
128
,
128
,
256
,
256
,
512
,
512
),
resolution
=
args
.
resolution
,
down_blocks
=
(
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
,
"UNetResDownBlock2D"
,
"UNetResAttnDownBlock2D"
,
"UNetResDownBlock2D"
,
),
up_blocks
=
(
"UNetResUpBlock2D"
,
"UNetResAttnUpBlock2D"
,
"UNetResUpBlock2D"
,
"UNetResUpBlock2D"
,
"UNetResUpBlock2D"
,
"UNetResUpBlock2D"
,
),
)
)
noise_scheduler
=
DD
I
MScheduler
(
timesteps
=
1000
,
tensor_format
=
"pt"
)
noise_scheduler
=
DD
P
MScheduler
(
num_train_
timesteps
=
1000
,
tensor_format
=
"pt"
)
optimizer
=
torch
.
optim
.
AdamW
(
optimizer
=
torch
.
optim
.
AdamW
(
model
.
parameters
(),
model
.
parameters
(),
lr
=
args
.
learning_rate
,
lr
=
args
.
learning_rate
,
...
@@ -92,19 +104,6 @@ def main(args):
...
@@ -92,19 +104,6 @@ def main(args):
run
=
os
.
path
.
split
(
__file__
)[
-
1
].
split
(
"."
)[
0
]
run
=
os
.
path
.
split
(
__file__
)[
-
1
].
split
(
"."
)[
0
]
accelerator
.
init_trackers
(
run
)
accelerator
.
init_trackers
(
run
)
# Train!
is_distributed
=
torch
.
distributed
.
is_available
()
and
torch
.
distributed
.
is_initialized
()
world_size
=
torch
.
distributed
.
get_world_size
()
if
is_distributed
else
1
total_train_batch_size
=
args
.
train_batch_size
*
args
.
gradient_accumulation_steps
*
world_size
max_steps
=
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
*
args
.
num_epochs
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
f
" Num examples =
{
len
(
train_dataloader
.
dataset
)
}
"
)
logger
.
info
(
f
" Num Epochs =
{
args
.
num_epochs
}
"
)
logger
.
info
(
f
" Instantaneous batch size per device =
{
args
.
train_batch_size
}
"
)
logger
.
info
(
f
" Total train batch size (w. parallel, distributed & accumulation) =
{
total_train_batch_size
}
"
)
logger
.
info
(
f
" Gradient Accumulation steps =
{
args
.
gradient_accumulation_steps
}
"
)
logger
.
info
(
f
" Total optimization steps =
{
max_steps
}
"
)
global_step
=
0
global_step
=
0
for
epoch
in
range
(
args
.
num_epochs
):
for
epoch
in
range
(
args
.
num_epochs
):
model
.
train
()
model
.
train
()
...
@@ -112,45 +111,37 @@ def main(args):
...
@@ -112,45 +111,37 @@ def main(args):
progress_bar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
progress_bar
.
set_description
(
f
"Epoch
{
epoch
}
"
)
for
step
,
batch
in
enumerate
(
train_dataloader
):
for
step
,
batch
in
enumerate
(
train_dataloader
):
clean_images
=
batch
[
"input"
]
clean_images
=
batch
[
"input"
]
noise_samples
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
# Sample noise that we'll add to the images
noise
=
torch
.
randn
(
clean_images
.
shape
).
to
(
clean_images
.
device
)
bsz
=
clean_images
.
shape
[
0
]
bsz
=
clean_images
.
shape
[
0
]
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
# Sample a random timestep for each image
timesteps
=
torch
.
randint
(
0
,
noise_scheduler
.
num_train_timesteps
,
(
bsz
,),
device
=
clean_images
.
device
).
long
()
#
a
dd noise
on
to the clean images according to the noise magnitude at each timestep
#
A
dd noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
# (this is the forward diffusion process)
noisy_images
=
noise_scheduler
.
add_noise
(
clean_images
,
noise_samples
,
timesteps
)
noisy_images
=
noise_scheduler
.
add_noise
(
clean_images
,
noise
,
timesteps
)
if
step
%
args
.
gradient_accumulation_steps
!=
0
:
with
accelerator
.
accumulate
(
model
):
with
accelerator
.
no_sync
(
model
):
# Predict the noise residual
output
=
model
(
noisy_images
,
timesteps
)
noise_pred
=
model
(
noisy_images
,
timesteps
)[
"sample"
]
# predict the noise residual
loss
=
F
.
mse_loss
(
noise_pred
,
noise
)
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
accelerator
.
backward
(
loss
)
else
:
output
=
model
(
noisy_images
,
timesteps
)
# predict the noise residual
loss
=
F
.
mse_loss
(
output
,
noise_samples
)
loss
=
loss
/
args
.
gradient_accumulation_steps
accelerator
.
backward
(
loss
)
accelerator
.
backward
(
loss
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
accelerator
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
optimizer
.
step
()
optimizer
.
step
()
lr_scheduler
.
step
()
lr_scheduler
.
step
()
if
args
.
use_ema
:
ema_model
.
step
(
model
)
ema_model
.
step
(
model
)
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
progress_bar
.
update
(
1
)
progress_bar
.
update
(
1
)
progress_bar
.
set_postfix
(
logs
=
{
"loss"
:
loss
.
detach
().
item
(),
"lr"
:
lr_scheduler
.
get_last_lr
()[
0
],
"step"
:
global_step
}
loss
=
loss
.
detach
().
item
(),
lr
=
optimizer
.
param_groups
[
0
][
"lr"
],
ema_decay
=
ema_model
.
decay
if
args
.
use_ema
:
)
logs
[
"ema_decay"
]
=
ema_model
.
decay
accelerator
.
log
(
progress_bar
.
set_postfix
(
**
logs
)
{
accelerator
.
log
(
logs
,
step
=
global_step
)
"train_loss"
:
loss
.
detach
().
item
(),
"epoch"
:
epoch
,
"ema_decay"
:
ema_model
.
decay
,
"step"
:
global_step
,
},
step
=
global_step
,
)
global_step
+=
1
global_step
+=
1
progress_bar
.
close
()
progress_bar
.
close
()
...
@@ -159,14 +150,14 @@ def main(args):
...
@@ -159,14 +150,14 @@ def main(args):
# Generate a sample image for visual inspection
# Generate a sample image for visual inspection
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pipeline
=
DD
I
MPipeline
(
pipeline
=
DD
P
MPipeline
(
unet
=
accelerator
.
unwrap_model
(
ema_model
.
averaged_model
),
unet
=
accelerator
.
unwrap_model
(
ema_model
.
averaged_model
if
args
.
use_ema
else
model
),
noise_
scheduler
=
noise_scheduler
,
scheduler
=
noise_scheduler
,
)
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
# run pipeline in inference (sample random noise and denoise)
# run pipeline in inference (sample random noise and denoise)
images
=
pipeline
(
generator
=
generator
,
batch_size
=
args
.
eval_batch_size
,
num_inference_steps
=
50
)
images
=
pipeline
(
generator
=
generator
,
batch_size
=
args
.
eval_batch_size
)
# denormalize the images and save to tensorboard
# denormalize the images and save to tensorboard
images_processed
=
(
images
.
cpu
()
+
1.0
)
*
127.5
images_processed
=
(
images
.
cpu
()
+
1.0
)
*
127.5
...
@@ -174,6 +165,7 @@ def main(args):
...
@@ -174,6 +165,7 @@ def main(args):
accelerator
.
trackers
[
0
].
writer
.
add_images
(
"test_samples"
,
images_processed
,
epoch
)
accelerator
.
trackers
[
0
].
writer
.
add_images
(
"test_samples"
,
images_processed
,
epoch
)
if
epoch
%
args
.
save_model_epochs
==
0
or
epoch
==
args
.
num_epochs
-
1
:
# save the model
# save the model
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
push_to_hub
(
args
,
pipeline
,
repo
,
commit_message
=
f
"Epoch
{
epoch
}
"
,
blocking
=
False
)
push_to_hub
(
args
,
pipeline
,
repo
,
commit_message
=
f
"Epoch
{
epoch
}
"
,
blocking
=
False
)
...
@@ -188,12 +180,13 @@ if __name__ == "__main__":
...
@@ -188,12 +180,13 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of a training script."
)
parser
=
argparse
.
ArgumentParser
(
description
=
"Simple example of a training script."
)
parser
.
add_argument
(
"--local_rank"
,
type
=
int
,
default
=-
1
)
parser
.
add_argument
(
"--local_rank"
,
type
=
int
,
default
=-
1
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"huggan/flowers-102-categories"
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"huggan/flowers-102-categories"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"ddpm-
model
"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
"ddpm-
flowers-64
"
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--resolution"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--train_batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--train_batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--eval_batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--eval_batch_size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--save_model_epochs"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--learning_rate"
,
type
=
float
,
default
=
1e-4
)
parser
.
add_argument
(
"--lr_scheduler"
,
type
=
str
,
default
=
"cosine"
)
parser
.
add_argument
(
"--lr_scheduler"
,
type
=
str
,
default
=
"cosine"
)
...
@@ -202,6 +195,7 @@ if __name__ == "__main__":
...
@@ -202,6 +195,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--adam_beta2"
,
type
=
float
,
default
=
0.999
)
parser
.
add_argument
(
"--adam_beta2"
,
type
=
float
,
default
=
0.999
)
parser
.
add_argument
(
"--adam_weight_decay"
,
type
=
float
,
default
=
1e-6
)
parser
.
add_argument
(
"--adam_weight_decay"
,
type
=
float
,
default
=
1e-6
)
parser
.
add_argument
(
"--adam_epsilon"
,
type
=
float
,
default
=
1e-3
)
parser
.
add_argument
(
"--adam_epsilon"
,
type
=
float
,
default
=
1e-3
)
parser
.
add_argument
(
"--use_ema"
,
action
=
"store_true"
,
default
=
True
)
parser
.
add_argument
(
"--ema_inv_gamma"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--ema_inv_gamma"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--ema_power"
,
type
=
float
,
default
=
3
/
4
)
parser
.
add_argument
(
"--ema_power"
,
type
=
float
,
default
=
3
/
4
)
parser
.
add_argument
(
"--ema_max_decay"
,
type
=
float
,
default
=
0.9999
)
parser
.
add_argument
(
"--ema_max_decay"
,
type
=
float
,
default
=
0.9999
)
...
...
src/diffusers/models/vae.py
View file @
13e37cab
...
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
...
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
print
(
"Working with z of shape {} = {} dimensions."
.
format
(
self
.
z_shape
,
np
.
prod
(
self
.
z_shape
)))
#
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
# z to block_in
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv_in
=
torch
.
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
...
...
src/diffusers/pipeline_utils.py
View file @
13e37cab
...
@@ -19,6 +19,7 @@ import os
...
@@ -19,6 +19,7 @@ import os
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
PIL
import
Image
from
.configuration_utils
import
ConfigMixin
from
.configuration_utils
import
ConfigMixin
from
.utils
import
DIFFUSERS_CACHE
,
logging
from
.utils
import
DIFFUSERS_CACHE
,
logging
...
@@ -120,6 +121,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -120,6 +121,7 @@ class DiffusionPipeline(ConfigMixin):
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
proxies
=
kwargs
.
pop
(
"proxies"
,
None
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
local_files_only
=
kwargs
.
pop
(
"local_files_only"
,
False
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
use_auth_token
=
kwargs
.
pop
(
"use_auth_token"
,
None
)
revision
=
kwargs
.
pop
(
"revision"
,
None
)
# 1. Download the checkpoints and configs
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
# use snapshot download here to get it working from from_pretrained
...
@@ -131,6 +133,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -131,6 +133,7 @@ class DiffusionPipeline(ConfigMixin):
proxies
=
proxies
,
proxies
=
proxies
,
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
use_auth_token
=
use_auth_token
,
use_auth_token
=
use_auth_token
,
revision
=
revision
,
)
)
else
:
else
:
cached_folder
=
pretrained_model_name_or_path
cached_folder
=
pretrained_model_name_or_path
...
@@ -187,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -187,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
# 5. Instantiate the pipeline
# 5. Instantiate the pipeline
model
=
pipeline_class
(
**
init_kwargs
)
model
=
pipeline_class
(
**
init_kwargs
)
return
model
return
model
@
staticmethod
def
numpy_to_pil
(
images
):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if
images
.
ndim
==
3
:
images
=
images
[
None
,
...]
images
=
(
images
*
255
).
round
().
astype
(
"uint8"
)
pil_images
=
[
Image
.
fromarray
(
image
)
for
image
in
images
]
return
pil_images
src/diffusers/pipelines/ddim/pipeline_ddim.py
View file @
13e37cab
...
@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
...
@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
,
output_type
=
"pil"
):
# eta corresponds to η in paper and should be between [0, 1]
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -53,4 +53,9 @@ class DDIMPipeline(DiffusionPipeline):
...
@@ -53,4 +53,9 @@ class DDIMPipeline(DiffusionPipeline):
# do x_t -> x_t-1
# do x_t -> x_t-1
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
,
eta
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
,
eta
)[
"prev_sample"
]
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
13e37cab
...
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
output_type
=
"pil"
):
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -54,4 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -54,4 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
# 3. set current image to prev_image: x_t -> x_t-1
# 3. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
image
=
pred_prev_image
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
View file @
13e37cab
...
@@ -4,7 +4,7 @@ import torch
...
@@ -4,7 +4,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.utils.checkpoint
import
torch.utils.checkpoint
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers.activations
import
ACT2FN
from
transformers.activations
import
ACT2FN
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.modeling_outputs
import
BaseModelOutput
from
transformers.modeling_outputs
import
BaseModelOutput
...
@@ -30,51 +30,67 @@ class LatentDiffusionPipeline(DiffusionPipeline):
...
@@ -30,51 +30,67 @@ class LatentDiffusionPipeline(DiffusionPipeline):
eta
=
0.0
,
eta
=
0.0
,
guidance_scale
=
1.0
,
guidance_scale
=
1.0
,
num_inference_steps
=
50
,
num_inference_steps
=
50
,
output_type
=
"pil"
,
):
):
# eta corresponds to η in paper and should be between [0, 1]
# eta corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
batch_size
=
len
(
prompt
)
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
self
.
bert
.
to
(
torch_device
)
self
.
bert
.
to
(
torch_device
)
# get unconditional embeddings for classifier free guid
e
nce
# get unconditional embeddings for classifier free guid
a
nce
if
guidance_scale
!=
1.0
:
if
guidance_scale
!=
1.0
:
uncond_input
=
self
.
tokenizer
([
""
],
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
).
to
(
uncond_input
=
self
.
tokenizer
([
""
]
*
batch_size
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
)
torch_device
uncond_embeddings
=
self
.
bert
(
uncond_input
.
input_ids
.
to
(
torch_device
))
)
uncond_embeddings
=
self
.
bert
(
uncond_input
.
input_ids
)
# get text embedding
# get
prompt
text embedding
s
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
)
.
to
(
torch_device
)
text_input
=
self
.
tokenizer
(
prompt
,
padding
=
"max_length"
,
max_length
=
77
,
return_tensors
=
"pt"
)
text_embedding
=
self
.
bert
(
text_input
.
input_ids
)
text_embedding
s
=
self
.
bert
(
text_input
.
input_ids
.
to
(
torch_device
)
)
image
=
torch
.
randn
(
latents
=
torch
.
randn
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
image_size
,
self
.
unet
.
image_size
),
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
image_size
,
self
.
unet
.
image_size
),
generator
=
generator
,
generator
=
generator
,
).
to
(
torch_device
)
)
latents
=
latents
.
to
(
torch_device
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
for
t
in
tqdm
.
tqdm
(
self
.
scheduler
.
timesteps
):
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
# 1. predict noise residual
if
guidance_scale
==
1.0
:
pred_noise_t
=
self
.
unet
(
image
,
t
,
encoder_hidden_states
=
text_embedding
)
# guidance_scale of 1 means no guidance
latents_input
=
latents
context
=
text_embeddings
else
:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input
=
torch
.
cat
([
latents
]
*
2
)
context
=
torch
.
cat
([
uncond_embeddings
,
text_embeddings
])
# predict the noise residual
noise_pred
=
self
.
unet
(
latents_input
,
t
,
encoder_hidden_states
=
context
)[
"sample"
]
# perform guidance
if
guidance_scale
!=
1.0
:
noise_pred_uncond
,
noise_prediction_text
=
noise_pred
.
chunk
(
2
)
noise_pred
=
noise_pred_uncond
+
guidance_scale
*
(
noise_prediction_text
-
noise_pred_uncond
)
if
isinstance
(
pred_noise_t
,
dict
):
# compute the previous noisy sample x_t -> x_t-1
pred_noise_t
=
pred_noise_t
[
"
sample"
]
latents
=
self
.
scheduler
.
step
(
noise_pred
,
t
,
latents
,
eta
)[
"prev_
sample"
]
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# scale and decode the image latents with vae
# do x_t -> x_t-1
latents
=
1
/
0.18215
*
latents
image
=
self
.
scheduler
.
step
(
pred_noise_t
,
t
,
image
,
eta
)[
"prev_sample"
]
image
=
self
.
vqvae
.
decode
(
latents
)
# scale and decode image with vae
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
1
/
0.18215
*
image
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
i
mage
=
self
.
vqvae
.
decode
(
image
)
i
f
output_type
==
"pil"
:
image
=
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
image
=
self
.
numpy_to_pil
(
image
)
return
image
return
{
"sample"
:
image
}
################################################################################
################################################################################
...
...
src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
View file @
13e37cab
...
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
...
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
def
__call__
(
self
,
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
,
output_type
=
"pil"
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
,
):
):
# eta corresponds to η in paper and should be between [0, 1]
# eta corresponds to η in paper and should be between [0, 1]
...
@@ -28,25 +23,26 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
...
@@ -28,25 +23,26 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
self
.
vqvae
.
to
(
torch_device
)
image
=
torch
.
randn
(
latents
=
torch
.
randn
(
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
image_size
,
self
.
unet
.
image_size
),
(
batch_size
,
self
.
unet
.
in_channels
,
self
.
unet
.
image_size
,
self
.
unet
.
image_size
),
generator
=
generator
,
generator
=
generator
,
).
to
(
torch_device
)
)
latents
=
latents
.
to
(
torch_device
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
for
t
in
tqdm
(
self
.
scheduler
.
timesteps
):
with
torch
.
no_grad
():
# predict the noise residual
model_output
=
self
.
unet
(
image
,
t
)
noise_prediction
=
self
.
unet
(
latents
,
t
)[
"sample"
]
# compute the previous noisy sample x_t -> x_t-1
latents
=
self
.
scheduler
.
step
(
noise_prediction
,
t
,
latents
,
eta
)[
"prev_sample"
]
if
isinstance
(
model_output
,
dict
):
# decode the image latents with the VAE
model_output
=
model_output
[
"sample"
]
image
=
self
.
vqvae
.
decode
(
latents
)
# 2. predict previous mean of image x_t-1 and add variance depending on eta
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
# do x_t -> x_t-1
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
,
eta
)[
"prev_sample"
]
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
# decode image with vae
with
torch
.
no_grad
():
image
=
self
.
vqvae
.
decode
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/pipelines/pndm/pipeline_pndm.py
View file @
13e37cab
...
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
self
.
register_modules
(
unet
=
unet
,
scheduler
=
scheduler
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
num_inference_steps
=
50
,
output_type
=
"pil"
):
# For more information on the sampling method you can take a look at Algorithm 2 of
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if
torch_device
is
None
:
if
torch_device
is
None
:
...
@@ -43,18 +43,20 @@ class PNDMPipeline(DiffusionPipeline):
...
@@ -43,18 +43,20 @@ class PNDMPipeline(DiffusionPipeline):
)
)
image
=
image
.
to
(
torch_device
)
image
=
image
.
to
(
torch_device
)
prk_time_steps
=
self
.
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
for
t
in
tqdm
(
range
(
len
(
prk_time_steps
))):
for
i
,
t
in
enumerate
(
tqdm
(
self
.
scheduler
.
prk_timesteps
)):
t_orig
=
prk_time_steps
[
t
]
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
model_output
=
self
.
unet
(
image
,
t_orig
)[
"sample"
]
image
=
self
.
scheduler
.
step_prk
(
model_output
,
t
,
image
,
num_inference_steps
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step_prk
(
model_output
,
i
,
image
,
num_inference_steps
)[
"prev_sample"
]
timesteps
=
self
.
scheduler
.
get_time_steps
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
tqdm
(
self
.
scheduler
.
plms_timesteps
)):
for
t
in
tqdm
(
range
(
len
(
timesteps
))):
model_output
=
self
.
unet
(
image
,
t
)[
"sample"
]
t_orig
=
timesteps
[
t
]
model_output
=
self
.
unet
(
image
,
t_orig
)[
"sample"
]
image
=
self
.
scheduler
.
step_plms
(
model_output
,
t
,
image
,
num_inference_steps
)[
"prev_sample"
]
image
=
self
.
scheduler
.
step_plms
(
model_output
,
i
,
image
,
num_inference_steps
)[
"prev_sample"
]
image
=
(
image
/
2
+
0.5
).
clamp
(
0
,
1
)
image
=
image
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
image
=
self
.
numpy_to_pil
(
image
)
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
View file @
13e37cab
...
@@ -2,15 +2,16 @@
...
@@ -2,15 +2,16 @@
import
torch
import
torch
from
diffusers
import
DiffusionPipeline
from
diffusers
import
DiffusionPipeline
from
tqdm.auto
import
tqdm
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class
ScoreSdeVePipeline
(
DiffusionPipeline
):
class
ScoreSdeVePipeline
(
DiffusionPipeline
):
def
__init__
(
self
,
model
,
scheduler
):
def
__init__
(
self
,
model
,
scheduler
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
self
.
register_modules
(
model
=
model
,
scheduler
=
scheduler
)
def
__call__
(
self
,
num_inference_steps
=
2000
,
generator
=
None
):
@
torch
.
no_grad
()
def
__call__
(
self
,
num_inference_steps
=
2000
,
generator
=
None
,
output_type
=
"pil"
):
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
)
img_size
=
self
.
model
.
config
.
image_size
img_size
=
self
.
model
.
config
.
image_size
...
@@ -24,11 +25,10 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -24,11 +25,10 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
self
.
scheduler
.
set_timesteps
(
num_inference_steps
)
self
.
scheduler
.
set_sigmas
(
num_inference_steps
)
self
.
scheduler
.
set_sigmas
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
self
.
scheduler
.
timesteps
):
for
i
,
t
in
tqdm
(
enumerate
(
self
.
scheduler
.
timesteps
)
)
:
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
sigma_t
=
self
.
scheduler
.
sigmas
[
i
]
*
torch
.
ones
(
shape
[
0
],
device
=
device
)
for
_
in
range
(
self
.
scheduler
.
correct_steps
):
for
_
in
range
(
self
.
scheduler
.
correct_steps
):
with
torch
.
no_grad
():
model_output
=
self
.
model
(
sample
,
sigma_t
)
model_output
=
self
.
model
(
sample
,
sigma_t
)
if
isinstance
(
model_output
,
dict
):
if
isinstance
(
model_output
,
dict
):
...
@@ -45,4 +45,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
...
@@ -45,4 +45,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
output
=
self
.
scheduler
.
step_pred
(
model_output
,
t
,
sample
)
output
=
self
.
scheduler
.
step_pred
(
model_output
,
t
,
sample
)
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
sample
,
sample_mean
=
output
[
"prev_sample"
],
output
[
"prev_sample_mean"
]
return
sample_mean
sample
=
sample
.
clamp
(
0
,
1
)
sample
=
sample
.
cpu
().
permute
(
0
,
2
,
3
,
1
).
numpy
()
if
output_type
==
"pil"
:
sample
=
self
.
numpy_to_pil
(
sample
)
return
{
"sample"
:
sample
}
src/diffusers/schedulers/scheduling_pndm.py
View file @
13e37cab
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import
math
import
math
import
pdb
from
typing
import
Union
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
one
=
np
.
array
(
1.0
)
self
.
one
=
np
.
array
(
1.0
)
self
.
set_format
(
tensor_format
=
tensor_format
)
# For now we only support F-PNDM, i.e. the runge-kutta method
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
# mainly at formula (9), (12), (13) and the Algorithm 2.
...
@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
cur_model_output
=
0
self
.
cur_model_output
=
0
self
.
cur_sample
=
None
self
.
cur_sample
=
None
self
.
ets
=
[]
self
.
ets
=
[]
self
.
prk_time_steps
=
{}
self
.
time_steps
=
{}
self
.
set_prk_mode
()
def
get_prk_time_steps
(
self
,
num_inference_steps
):
# setable values
if
num_inference_steps
in
self
.
prk_time_steps
:
self
.
num_inference_steps
=
None
return
self
.
prk_time_steps
[
num_inference_steps
]
self
.
timesteps
=
np
.
arange
(
0
,
num_train_timesteps
)[::
-
1
].
copy
()
self
.
prk_timesteps
=
None
self
.
plms_timesteps
=
None
self
.
tensor_format
=
tensor_format
self
.
set_format
(
tensor_format
=
tensor_format
)
inference_step_times
=
list
(
def
set_timesteps
(
self
,
num_inference_steps
):
self
.
num_inference_steps
=
num_inference_steps
self
.
timesteps
=
list
(
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
)
)
prk_time_steps
=
np
.
array
(
inference_step_time
s
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
prk_time_steps
=
np
.
array
(
self
.
timestep
s
[
-
self
.
pndm_order
:]).
repeat
(
2
)
+
np
.
tile
(
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
np
.
array
([
0
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
//
2
]),
self
.
pndm_order
)
)
self
.
prk_time_steps
[
num_inference_steps
]
=
list
(
reversed
(
prk_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
prk_timesteps
=
list
(
reversed
(
prk_time_steps
[:
-
1
].
repeat
(
2
)[
1
:
-
1
]))
self
.
plms_timesteps
=
list
(
reversed
(
self
.
timesteps
[:
-
3
]))
return
self
.
prk_time_steps
[
num_inference_steps
]
def
get_time_steps
(
self
,
num_inference_steps
):
if
num_inference_steps
in
self
.
time_steps
:
return
self
.
time_steps
[
num_inference_steps
]
inference_step_times
=
list
(
range
(
0
,
self
.
config
.
num_train_timesteps
,
self
.
config
.
num_train_timesteps
//
num_inference_steps
)
)
self
.
time_steps
[
num_inference_steps
]
=
list
(
reversed
(
inference_step_times
[:
-
3
]))
return
self
.
time_steps
[
num_inference_steps
]
def
set_prk_mode
(
self
):
self
.
mode
=
"prk"
def
set_plms_mode
(
self
):
self
.
mode
=
"plms"
def
step
(
self
,
*
args
,
**
kwargs
):
if
self
.
mode
==
"prk"
:
return
self
.
step_prk
(
*
args
,
**
kwargs
)
if
self
.
mode
==
"plms"
:
return
self
.
step_plms
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"mode
{
self
.
mode
}
does not exist."
)
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
def
step_prk
(
def
step_prk
(
self
,
self
,
...
@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
solution to the differential equation.
solution to the differential equation.
"""
"""
t
=
timestep
t
=
timestep
prk_time_steps
=
self
.
get_
prk_time
_
steps
(
num_inference_steps
)
prk_time_steps
=
self
.
prk_timesteps
t_orig
=
prk_time_steps
[
t
//
4
*
4
]
t_orig
=
prk_time_steps
[
t
//
4
*
4
]
t_orig_prev
=
prk_time_steps
[
min
(
t
+
1
,
len
(
prk_time_steps
)
-
1
)]
t_orig_prev
=
prk_time_steps
[
min
(
t
+
1
,
len
(
prk_time_steps
)
-
1
)]
...
@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
"for more information."
)
)
timesteps
=
self
.
get
_time
_
steps
(
num_inference_steps
)
timesteps
=
self
.
plms
_timesteps
t_orig
=
timesteps
[
t
]
t_orig
=
timesteps
[
t
]
t_orig_prev
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
t_orig_prev
=
timesteps
[
min
(
t
+
1
,
len
(
timesteps
)
-
1
)]
...
...
tests/test_modeling_utils.py
View file @
13e37cab
...
@@ -18,11 +18,11 @@ import inspect
...
@@ -18,11 +18,11 @@ import inspect
import
math
import
math
import
tempfile
import
tempfile
import
unittest
import
unittest
from
atexit
import
register
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
PIL
from
diffusers
import
UNetConditionalModel
# noqa: F401 TODO(Patrick) - need to write tests with it
from
diffusers
import
UNetConditionalModel
# noqa: F401 TODO(Patrick) - need to write tests with it
from
diffusers
import
(
from
diffusers
import
(
AutoencoderKL
,
AutoencoderKL
,
...
@@ -704,11 +704,11 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -704,11 +704,11 @@ class PipelineTesterMixin(unittest.TestCase):
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)[
"sample"
]
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
generator
=
generator
.
manual_seed
(
0
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
new_ddpm
(
generator
=
generator
)[
"sample"
]
new_image
=
new_ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
assert
np
.
abs
(
image
-
new_image
).
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
@
slow
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
...
@@ -722,11 +722,32 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -722,11 +722,32 @@ class PipelineTesterMixin(unittest.TestCase):
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)[
"sample"
]
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
generator
=
generator
.
manual_seed
(
0
)
generator
=
generator
.
manual_seed
(
0
)
new_image
=
ddpm_from_hub
(
generator
=
generator
)[
"sample"
]
new_image
=
ddpm_from_hub
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
assert
np
.
abs
(
image
-
new_image
).
sum
()
<
1e-5
,
"Models don't give the same forward pass"
@
slow
def
test_output_format
(
self
):
model_path
=
"google/ddpm-cifar10-32"
pipe
=
DDIMPipeline
.
from_pretrained
(
model_path
)
generator
=
torch
.
manual_seed
(
0
)
images
=
pipe
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
assert
images
.
shape
==
(
1
,
32
,
32
,
3
)
assert
isinstance
(
images
,
np
.
ndarray
)
images
=
pipe
(
generator
=
generator
,
output_type
=
"pil"
)[
"sample"
]
assert
isinstance
(
images
,
list
)
assert
len
(
images
)
==
1
assert
isinstance
(
images
[
0
],
PIL
.
Image
.
Image
)
# use PIL by default
images
=
pipe
(
generator
=
generator
)[
"sample"
]
assert
isinstance
(
images
,
list
)
assert
isinstance
(
images
[
0
],
PIL
.
Image
.
Image
)
@
slow
@
slow
def
test_ddpm_cifar10
(
self
):
def
test_ddpm_cifar10
(
self
):
...
@@ -739,15 +760,13 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -739,15 +760,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDPMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)[
"sample"
]
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
3
:
,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
expected_slice
=
torch
.
tensor
(
expected_slice
=
np
.
array
([
0.41995
,
0.35885
,
0.19385
,
0.38475
,
0.3382
,
0.2647
,
0.41545
,
0.3582
,
0.33845
])
[
-
0.1601
,
-
0.2823
,
-
0.6123
,
-
0.2305
,
-
0.3236
,
-
0.4706
,
-
0.1691
,
-
0.2836
,
-
0.3231
]
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
@
slow
def
test_ddim_lsun
(
self
):
def
test_ddim_lsun
(
self
):
...
@@ -759,15 +778,13 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -759,15 +778,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddpm
(
generator
=
generator
)[
"sample"
]
image
=
ddpm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
3
:
,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
(
expected_slice
=
np
.
array
([
0.00605
,
0.0201
,
0.0344
,
0.00235
,
0.00185
,
0.00025
,
0.00215
,
0.0
,
0.00685
])
[
-
0.9879
,
-
0.9598
,
-
0.9312
,
-
0.9953
,
-
0.9963
,
-
0.9995
,
-
0.9957
,
-
1.0000
,
-
0.9863
]
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
@
slow
def
test_ddim_cifar10
(
self
):
def
test_ddim_cifar10
(
self
):
...
@@ -779,15 +796,13 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -779,15 +796,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
)[
"sample"
]
image
=
ddim
(
generator
=
generator
,
eta
=
0.0
,
output_type
=
"numpy"
)[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
3
:
,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
expected_slice
=
torch
.
tensor
(
expected_slice
=
np
.
array
([
0.17235
,
0.16175
,
0.16005
,
0.16255
,
0.1497
,
0.1513
,
0.15045
,
0.1442
,
0.1453
])
[
-
0.6553
,
-
0.6765
,
-
0.6799
,
-
0.6749
,
-
0.7006
,
-
0.6974
,
-
0.6991
,
-
0.7116
,
-
0.7094
]
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
@
slow
def
test_pndm_cifar10
(
self
):
def
test_pndm_cifar10
(
self
):
...
@@ -798,15 +813,13 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -798,15 +813,13 @@ class PipelineTesterMixin(unittest.TestCase):
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
pndm
=
PNDMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
pndm
(
generator
=
generator
)[
"sample"
]
image
=
pndm
(
generator
=
generator
,
output_type
=
"numpy"
)[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
3
:
,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
3
,
32
,
32
)
assert
image
.
shape
==
(
1
,
32
,
32
,
3
)
expected_slice
=
torch
.
tensor
(
expected_slice
=
np
.
array
([
0.1564
,
0.14645
,
0.1406
,
0.14715
,
0.12425
,
0.14045
,
0.13115
,
0.12175
,
0.125
])
[
-
0.6872
,
-
0.7071
,
-
0.7188
,
-
0.7057
,
-
0.7515
,
-
0.7191
,
-
0.7377
,
-
0.7565
,
-
0.7500
]
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
@
slow
@
slow
def
test_ldm_text2img
(
self
):
def
test_ldm_text2img
(
self
):
...
@@ -814,13 +827,15 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -814,13 +827,15 @@ class PipelineTesterMixin(unittest.TestCase):
prompt
=
"A painting of a squirrel eating a burger"
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
20
)
image
=
ldm
([
prompt
],
generator
=
generator
,
guidance_scale
=
6.0
,
num_inference_steps
=
20
,
output_type
=
"numpy"
)[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
3
:
,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
expected_slice
=
np
.
array
([
0.9256
,
0.9340
,
0.8933
,
0.9361
,
0.9113
,
0.8727
,
0.9122
,
0.8745
,
0.8099
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
@
slow
@
slow
def
test_ldm_text2img_fast
(
self
):
def
test_ldm_text2img_fast
(
self
):
...
@@ -828,55 +843,44 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -828,55 +843,44 @@ class PipelineTesterMixin(unittest.TestCase):
prompt
=
"A painting of a squirrel eating a burger"
prompt
=
"A painting of a squirrel eating a burger"
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
1
)
image
=
ldm
([
prompt
],
generator
=
generator
,
num_inference_steps
=
1
,
output_type
=
"numpy"
)[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
3
:
,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
([
0.3163
,
0.8670
,
0.6465
,
0.1865
,
0.6291
,
0.5139
,
0.2824
,
0.3723
,
0.4344
])
expected_slice
=
np
.
array
([
0.3163
,
0.8670
,
0.6465
,
0.1865
,
0.6291
,
0.5139
,
0.2824
,
0.3723
,
0.4344
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
@
slow
@
slow
def
test_score_sde_ve_pipeline
(
self
):
def
test_score_sde_ve_pipeline
(
self
):
model
=
UNetUnconditionalModel
.
from_pretrained
(
"google/ncsnpp-
ffhq-1024
"
)
model
=
UNetUnconditionalModel
.
from_pretrained
(
"google/ncsnpp-
church-256
"
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
"google/ncsnpp-
ffhq-1024
"
)
scheduler
=
ScoreSdeVeScheduler
.
from_config
(
"google/ncsnpp-
church-256
"
)
sde_ve
=
ScoreSdeVePipeline
(
model
=
model
,
scheduler
=
scheduler
)
sde_ve
=
ScoreSdeVePipeline
(
model
=
model
,
scheduler
=
scheduler
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
image
=
sde_ve
(
num_inference_steps
=
2
)
image
=
sde_ve
(
num_inference_steps
=
300
,
output_type
=
"numpy"
)[
"sample"
]
if
model
.
device
.
type
==
"cpu"
:
# patrick's cpu
expected_image_sum
=
3384805888.0
expected_image_mean
=
1076.00085
# m1 mbp
image_slice
=
image
[
0
,
-
3
:,
-
3
:,
-
1
]
# expected_image_sum = 3384805376.0
# expected_image_mean = 1076.000610351562
else
:
expected_image_sum
=
3382849024.0
expected_image_mean
=
1075.3788
assert
(
image
.
abs
().
sum
()
-
expected_image_sum
).
abs
().
cpu
().
item
()
<
1e-2
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
assert
(
image
.
abs
().
mean
()
-
expected_image_mean
).
abs
().
cpu
().
item
()
<
1e-4
expected_slice
=
np
.
array
([
0.64363
,
0.5868
,
0.3031
,
0.2284
,
0.7409
,
0.3216
,
0.25643
,
0.6557
,
0.2633
])
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
@
slow
@
slow
def
test_ldm_uncond
(
self
):
def
test_ldm_uncond
(
self
):
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"CompVis/ldm-celebahq-256"
)
ldm
=
LatentDiffusionUncondPipeline
.
from_pretrained
(
"CompVis/ldm-celebahq-256"
)
generator
=
torch
.
manual_seed
(
0
)
generator
=
torch
.
manual_seed
(
0
)
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
)[
"sample"
]
image
=
ldm
(
generator
=
generator
,
num_inference_steps
=
5
,
output_type
=
"numpy"
)[
"sample"
]
image_slice
=
image
[
0
,
-
1
,
-
3
:,
-
3
:].
cpu
()
image_slice
=
image
[
0
,
-
3
:
,
-
3
:,
-
1
]
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
assert
image
.
shape
==
(
1
,
256
,
256
,
3
)
expected_slice
=
torch
.
tensor
(
expected_slice
=
np
.
array
([
0.4399
,
0.44975
,
0.46825
,
0.474
,
0.4359
,
0.4581
,
0.45095
,
0.4341
,
0.4447
])
[
-
0.1202
,
-
0.1005
,
-
0.0635
,
-
0.0520
,
-
0.1282
,
-
0.0838
,
-
0.0981
,
-
0.1318
,
-
0.1106
]
assert
np
.
abs
(
image_slice
.
flatten
()
-
expected_slice
).
max
()
<
1e-2
)
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
tests/test_scheduler.py
View file @
13e37cab
...
@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
...
@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
...
@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
config
.
update
(
**
kwargs
)
config
.
update
(
**
kwargs
)
return
config
return
config
def
check_over_configs
_pmls
(
self
,
time_step
=
0
,
**
config
):
def
check_over_configs
(
self
,
time_step
=
0
,
**
config
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler_config
=
self
.
get_scheduler_config
(
**
config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
kwargs
[
"num_inference_steps"
])
# copy over dummy past residuals
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
set_plms_mode
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
.
set_timesteps
(
kwargs
[
"num_inference_steps"
])
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_plms_mode
()
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step
_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
check_over_forward_pmls
(
self
,
time_step
=
0
,
**
forward_kwargs
):
output
=
scheduler
.
step_plms
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_plms
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_from_pretrained_save_pretrained
(
self
):
pass
def
check_over_forward
(
self
,
time_step
=
0
,
**
forward_kwargs
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
.
update
(
forward_kwargs
)
kwargs
.
update
(
forward_kwargs
)
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
...
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_timesteps
(
kwargs
[
"num_inference_steps"
])
# copy over dummy past residuals
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler
.
set_plms_mode
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
scheduler
.
save_config
(
tmpdirname
)
scheduler
.
save_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
new_scheduler
=
scheduler_class
.
from_config
(
tmpdirname
)
# copy over dummy past residuals
# copy over dummy past residuals
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
ets
=
dummy_past_residuals
[:]
new_scheduler
.
set_
plms_mode
(
)
new_scheduler
.
set_
timesteps
(
kwargs
[
"num_inference_steps"
]
)
output
=
scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
output
=
scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_prk
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
output
=
scheduler
.
step_plms
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
new_output
=
new_scheduler
.
step_plms
(
residual
,
time_step
,
sample
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
assert
np
.
sum
(
np
.
abs
(
output
-
new_output
))
<
1e-5
,
"Scheduler outputs are not identical"
def
test_pytorch_equal_numpy
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
sample_pt
=
torch
.
tensor
(
sample
)
residual_pt
=
0.1
*
sample_pt
dummy_past_residuals_pt
=
[
residual_pt
+
0.2
,
residual_pt
+
0.15
,
residual_pt
+
0.1
,
residual_pt
+
0.05
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
# copy over dummy past residuals
scheduler
.
ets
=
dummy_past_residuals
[:]
scheduler_pt
=
scheduler_class
(
tensor_format
=
"pt"
,
**
scheduler_config
)
# copy over dummy past residuals
scheduler_pt
.
ets
=
dummy_past_residuals_pt
[:]
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
scheduler_pt
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_prk
(
residual_pt
,
1
,
sample_pt
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
output
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_pt
=
scheduler_pt
.
step_plms
(
residual_pt
,
1
,
sample_pt
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
assert
np
.
sum
(
np
.
abs
(
output
-
output_pt
.
numpy
()))
<
1e-4
,
"Scheduler outputs are not identical"
def
test_step_shape
(
self
):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
num_inference_steps
=
kwargs
.
pop
(
"num_inference_steps"
,
None
)
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
# copy over dummy past residuals
dummy_past_residuals
=
[
residual
+
0.2
,
residual
+
0.15
,
residual
+
0.1
,
residual
+
0.05
]
scheduler
.
ets
=
dummy_past_residuals
[:]
if
num_inference_steps
is
not
None
and
hasattr
(
scheduler
,
"set_timesteps"
):
scheduler
.
set_timesteps
(
num_inference_steps
)
elif
num_inference_steps
is
not
None
and
not
hasattr
(
scheduler
,
"set_timesteps"
):
kwargs
[
"num_inference_steps"
]
=
num_inference_steps
output_0
=
scheduler
.
step_prk
(
residual
,
0
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_prk
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
output_0
=
scheduler
.
step_plms
(
residual
,
0
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
output_1
=
scheduler
.
step_plms
(
residual
,
1
,
sample
,
num_inference_steps
,
**
kwargs
)[
"prev_sample"
]
self
.
assertEqual
(
output_0
.
shape
,
sample
.
shape
)
self
.
assertEqual
(
output_0
.
shape
,
output_1
.
shape
)
def
test_timesteps
(
self
):
def
test_timesteps
(
self
):
for
timesteps
in
[
100
,
1000
]:
for
timesteps
in
[
100
,
1000
]:
self
.
check_over_configs
(
num_train_timesteps
=
timesteps
)
self
.
check_over_configs
(
num_train_timesteps
=
timesteps
)
def
test_timesteps_pmls
(
self
):
for
timesteps
in
[
100
,
1000
]:
self
.
check_over_configs_pmls
(
num_train_timesteps
=
timesteps
)
def
test_betas
(
self
):
def
test_betas
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.0001
,
0.001
,
0.01
],
[
0.002
,
0.02
,
0.2
]):
for
beta_start
,
beta_end
in
zip
([
0.0001
,
0.001
,
0.01
],
[
0.002
,
0.02
,
0.2
]):
self
.
check_over_configs
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
self
.
check_over_configs
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_betas_pmls
(
self
):
for
beta_start
,
beta_end
in
zip
([
0.0001
,
0.001
,
0.01
],
[
0.002
,
0.02
,
0.2
]):
self
.
check_over_configs_pmls
(
beta_start
=
beta_start
,
beta_end
=
beta_end
)
def
test_schedules
(
self
):
def
test_schedules
(
self
):
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_schedules_pmls
(
self
):
for
schedule
in
[
"linear"
,
"squaredcos_cap_v2"
]:
self
.
check_over_configs
(
beta_schedule
=
schedule
)
def
test_time_indices
(
self
):
def
test_time_indices
(
self
):
for
t
in
[
1
,
5
,
10
]:
for
t
in
[
1
,
5
,
10
]:
self
.
check_over_forward
(
time_step
=
t
)
self
.
check_over_forward
(
time_step
=
t
)
def
test_time_indices_pmls
(
self
):
for
t
in
[
1
,
5
,
10
]:
self
.
check_over_forward_pmls
(
time_step
=
t
)
def
test_inference_steps
(
self
):
def
test_inference_steps
(
self
):
for
t
,
num_inference_steps
in
zip
([
1
,
5
,
10
],
[
10
,
50
,
100
]):
for
t
,
num_inference_steps
in
zip
([
1
,
5
,
10
],
[
10
,
50
,
100
]):
self
.
check_over_forward
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
self
.
check_over_forward
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
def
test_inference_steps_pmls
(
self
):
def
test_inference_plms_no_past_residuals
(
self
):
for
t
,
num_inference_steps
in
zip
([
1
,
5
,
10
],
[
10
,
50
,
100
]):
self
.
check_over_forward_pmls
(
time_step
=
t
,
num_inference_steps
=
num_inference_steps
)
def
test_inference_pmls_no_past_residuals
(
self
):
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
.
set_plms_mode
()
scheduler
.
step_plms
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
,
50
)[
"prev_sample"
]
scheduler
.
step
(
self
.
dummy_sample
,
1
,
self
.
dummy_sample
,
50
)[
"prev_sample"
]
def
test_full_loop_no_noise
(
self
):
def
test_full_loop_no_noise
(
self
):
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_class
=
self
.
scheduler_classes
[
0
]
...
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
...
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
num_inference_steps
=
10
num_inference_steps
=
10
model
=
self
.
dummy_model
()
model
=
self
.
dummy_model
()
sample
=
self
.
dummy_sample_deter
sample
=
self
.
dummy_sample_deter
scheduler
.
set_timesteps
(
num_inference_steps
)
prk_time_steps
=
scheduler
.
get_prk_time_steps
(
num_inference_steps
)
for
i
,
t
in
enumerate
(
scheduler
.
prk_timesteps
):
for
t
in
range
(
len
(
prk_time_steps
)):
residual
=
model
(
sample
,
t
)
t_orig
=
prk_time_steps
[
t
]
sample
=
scheduler
.
step_prk
(
residual
,
i
,
sample
,
num_inference_steps
)[
"prev_sample"
]
residual
=
model
(
sample
,
t_orig
)
sample
=
scheduler
.
step_prk
(
residual
,
t
,
sample
,
num_inference_steps
)[
"prev_sample"
]
timesteps
=
scheduler
.
get_time_steps
(
num_inference_steps
)
for
t
in
range
(
len
(
timesteps
)):
t_orig
=
timesteps
[
t
]
residual
=
model
(
sample
,
t_orig
)
sample
=
scheduler
.
step_plms
(
residual
,
t
,
sample
,
num_inference_steps
)[
"prev_sample"
]
for
i
,
t
in
enumerate
(
scheduler
.
plms_timesteps
):
residual
=
model
(
sample
,
t
)
sample
=
scheduler
.
step_plms
(
residual
,
i
,
sample
,
num_inference_steps
)[
"prev_sample"
]
result_sum
=
np
.
sum
(
np
.
abs
(
sample
))
result_sum
=
np
.
sum
(
np
.
abs
(
sample
))
result_mean
=
np
.
mean
(
np
.
abs
(
sample
))
result_mean
=
np
.
mean
(
np
.
abs
(
sample
))
...
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
kwargs
=
dict
(
self
.
forward_default_kwargs
)
kwargs
=
dict
(
self
.
forward_default_kwargs
)
for
scheduler_class
in
self
.
scheduler_classes
:
for
scheduler_class
in
self
.
scheduler_classes
:
scheduler_class
=
self
.
scheduler_classes
[
0
]
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
...
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
...
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
sample
=
self
.
dummy_sample
sample
=
self
.
dummy_sample
residual
=
0.1
*
sample
residual
=
0.1
*
sample
scheduler_class
=
self
.
scheduler_classes
[
0
]
scheduler_config
=
self
.
get_scheduler_config
()
scheduler_config
=
self
.
get_scheduler_config
()
scheduler
=
scheduler_class
(
**
scheduler_config
)
scheduler
=
scheduler_class
(
**
scheduler_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment