renzhc / diffusers_dcu · Commits

Commit 542c7868, authored Jun 14, 2022 by patil-suraj

    Merge branch 'main' of https://github.com/huggingface/diffusers into main

Parents: 147d8e07, da1f920e
Showing 8 changed files with 286 additions and 174 deletions (+286 -174):

    Makefile                                     +1    -1
    _                                            +156  -0
    examples/training_ddpm.py                    +34   -22
    src/diffusers/configuration_utils.py         +2    -5
    src/diffusers/pipelines/pipeline_glide.py    +8    -8
    src/diffusers/pipelines/pipeline_pndm.py     +12   -87
    src/diffusers/schedulers/scheduling_pndm.py  +57   -49
    tests/test_modeling_utils.py                 +16   -2
Makefile

@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src

-check_dirs := tests src utils
+check_dirs := examples tests src utils

 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
...
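Adding `examples` to `check_dirs` means the quality targets now lint the example scripts as well. For orientation, here is a hedged Python sketch of what a `utils/get_modified_files.py`-style helper typically does; the script's actual implementation is not part of this diff, so the git invocation below is an assumption:

import subprocess
import sys

# Assumed behavior (not shown in this diff): print the tracked *.py files
# under the given directories that differ from main, so that
# `make modified_only_fixup` can lint only what changed.
def modified_py_files(check_dirs):
    out = subprocess.run(
        ["git", "diff", "--name-only", "main...HEAD"],
        capture_output=True, text=True, check=True,
    ).stdout.splitlines()
    return [
        f for f in out
        if f.endswith(".py") and any(f.startswith(d + "/") for d in check_dirs)
    ]

if __name__ == "__main__":
    # after this commit, check_dirs includes "examples" as well
    print(" ".join(modified_py_files(sys.argv[1:] or ["examples", "tests", "src", "utils"])))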
_ (new file, 0 → 100644)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

import tqdm

from ..pipeline_utils import DiffusionPipeline


class PNDM(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
        # eta corresponds to η in paper and should be between [0, 1]
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        num_trained_timesteps = self.noise_scheduler.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        self.unet.to(torch_device)

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            generator=generator,
        )
        image = image.to(torch_device)

        seq = list(inference_step_times)
        seq_next = [-1] + list(seq[:-1])
        model = self.unet

        warmup_steps = [len(seq) - (i // 4 + 1) for i in range(3 * 4)]

        ets = []
        prev_image = image
        for i, step_idx in enumerate(warmup_steps):
            i = seq[step_idx]
            j = seq_next[step_idx]

            t = torch.ones(image.shape[0]) * i
            t_next = torch.ones(image.shape[0]) * j
            t_list = [t, (t + t_next) / 2, t_next]  # sub-step times, as in pipeline_pndm.py

            residual = model(image.to("cuda"), t.to("cuda"))
            residual = residual.to("cpu")
            image = image.to("cpu")
            image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_list[0], t_list[1], residual)

            if i % 4 == 0:
                ets.append(residual)
                prev_image = image

        ets = []
        step_idx = len(seq) - 1
        while step_idx >= 0:
            i = seq[step_idx]
            j = seq_next[step_idx]

            t = torch.ones(image.shape[0]) * i
            t_next = torch.ones(image.shape[0]) * j

            residual = model(image.to("cuda"), t.to("cuda"))
            residual = residual.to("cpu")

            t_list = [t, (t + t_next) / 2, t_next]
            ets.append(residual)

            if len(ets) <= 3:
                image = image.to("cpu")
                x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual)
                e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
                x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
                e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
                x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
                e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
                residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
            else:
                residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])

            img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
            image = img_next
            step_idx = step_idx - 1

            # if len(prev_noises) in [1, 2]:
            #     t = (t + t_next) / 2
            # elif len(prev_noises) == 3:
            #     t = t_next / 2

            # if len(prev_noises) == 0:
            #     ets.append(residual)
            #
            # if len(ets) > 3:
            #     residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
            #     step_idx = step_idx - 1
            # elif len(ets) <= 3 and len(prev_noises) == 3:
            #     residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual)
            #     prev_noises = []
            #     step_idx = step_idx - 1
            # elif len(ets) <= 3 and len(prev_noises) < 3:
            #     prev_noises.append(residual)
            #     if len(prev_noises) < 2:
            #         t_next = (t + t_next) / 2
            #
            # image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)

        return image

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding
        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
        # for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
        #     # 1. predict noise residual
        #     with torch.no_grad():
        #         residual = self.unet(image, inference_step_times[t])
        #
        #     # 2. predict previous mean of image x_t-1
        #     pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
        #
        #     # 3. optionally sample variance
        #     variance = 0
        #     if eta > 0:
        #         noise = torch.randn(image.shape, generator=generator).to(image.device)
        #         variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
        #
        #     # 4. set current image to prev_image: x_t -> x_t-1
        #     image = pred_prev_image + variance
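The sampler above bootstraps with a Runge-Kutta warmup, then switches to a fourth-order linear multistep update once four residuals are banked: residual = (1/24)(55·e[-1] - 59·e[-2] + 37·e[-3] - 9·e[-4]). Those are the classic Adams-Bashforth coefficients. A minimal toy sketch, not from the commit (dy/dt = -y is a made-up test problem), shows the same combination driving an ODE step:

import math

# Toy demonstration of the 4th-order Adams-Bashforth rule used above,
# applied to dy/dt = -y, whose exact solution is y(t) = exp(-t).
def f(y):
    return -y

h = 0.1
ys = [math.exp(-i * h) for i in range(4)]  # bootstrap the first 4 values exactly
fs = [f(y) for y in ys]                    # residual history, analogous to `ets`

for step in range(4, 20):
    # same weighted combination of the last four residuals as in the code above
    f_comb = (1 / 24) * (55 * fs[-1] - 59 * fs[-2] + 37 * fs[-3] - 9 * fs[-4])
    ys.append(ys[-1] + h * f_comb)
    fs.append(f(ys[-1]))

print(abs(ys[-1] - math.exp(-1.9)))        # tiny error: the rule is 4th order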
src/diffusers/trainers/training_ddpm.py → examples/training_ddpm.py

@@ -8,14 +8,23 @@ import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
-from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
+from torchvision.transforms import (
+    Compose,
+    InterpolationMode,
+    Lambda,
+    RandomCrop,
+    RandomHorizontalFlip,
+    RandomVerticalFlip,
+    Resize,
+    ToTensor,
+)
 from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup


 def set_seed(seed):
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)

@@ -30,13 +39,13 @@ model = UNetModel(
     attn_resolutions=(16,),
     ch=128,
     ch_mult=(1, 2, 2, 2),
-    dropout=0.1,
+    dropout=0.0,
     num_res_blocks=2,
     resamp_with_conv=True,
-    resolution=32
+    resolution=32,
 )
 noise_scheduler = DDPMScheduler(timesteps=1000)
-optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
+optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
 num_epochs = 100
 batch_size = 64

@@ -44,14 +53,15 @@ gradient_accumulation_steps = 2
 augmentations = Compose(
     [
-        Resize(32),
-        CenterCrop(32),
+        Resize(32, interpolation=InterpolationMode.BILINEAR),
         RandomHorizontalFlip(),
+        RandomVerticalFlip(),
+        RandomCrop(32),
         ToTensor(),
         Lambda(lambda x: x * 2 - 1),
     ]
 )
-dataset = load_dataset("huggan/pokemon", split="train")
+dataset = load_dataset("huggan/flowers-102-categories", split="train")


 def transforms(examples):

@@ -59,24 +69,24 @@ def transforms(examples):
     return {"input": images}


+dataset = dataset.shuffle(seed=0)
 dataset.set_transform(transforms)
-train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
+train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

-# lr_scheduler = get_linear_schedule_with_warmup(
-#     optimizer=optimizer,
-#     num_warmup_steps=1000,
-#     num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
-# )
+lr_scheduler = get_linear_schedule_with_warmup(
+    optimizer=optimizer,
+    num_warmup_steps=500,
+    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
+)

-model, optimizer, train_dataloader = accelerator.prepare(
-    model, optimizer, train_dataloader
+model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+    model, optimizer, train_dataloader, lr_scheduler
 )

 for epoch in range(num_epochs):
     model.train()
     pbar = tqdm(total=len(train_dataloader), unit="ba")
     pbar.set_description(f"Epoch {epoch}")
+    losses = []
     for step, batch in enumerate(train_dataloader):
         clean_images = batch["input"]
         noisy_images = torch.empty_like(clean_images)

@@ -101,10 +111,12 @@ for epoch in range(num_epochs):
             accelerator.backward(loss)
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
-            # lr_scheduler.step()
+            lr_scheduler.step()
             optimizer.zero_grad()
+        loss = loss.detach().item()
+        losses.append(loss)
         pbar.update(1)
-        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
+        pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])
         optimizer.step()

@@ -124,5 +136,5 @@ for epoch in range(num_epochs):
         image_pil = PIL.Image.fromarray(image_processed[0])

         # save image
-        pipeline.save_pretrained("./poke-ddpm")
-        image_pil.save(f"./poke-ddpm/test_{epoch}.png")
+        pipeline.save_pretrained("./flowers-ddpm")
+        image_pil.save(f"./flowers-ddpm/test_{epoch}.png")
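Two behavioral changes stand out here: the DataLoader now shuffles, and the linear warmup schedule is switched on with num_warmup_steps=500 instead of the commented-out 1000. As a rough sketch of the learning-rate shape this produces (an illustrative re-implementation, not the transformers source; num_training_steps=10_000 is an arbitrary example value):

# Illustrative only: linear warmup then linear decay, the shape produced by
# transformers.get_linear_schedule_with_warmup with the values above.
def lr_at_step(step, base_lr=3e-4, num_warmup_steps=500, num_training_steps=10_000):
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)      # linear ramp-up
    return base_lr * max(
        0.0,
        (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps),
    )                                                         # linear decay to zero

for s in (0, 250, 500, 5_000, 10_000):
    print(s, lr_at_step(s))   # 0.0, 1.5e-04, 3e-04, ~1.58e-04, 0.0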
src/diffusers/configuration_utils.py

@@ -225,11 +225,8 @@ class ConfigMixin:
             text = reader.read()
         return json.loads(text)

-    # def __eq__(self, other):
-    #     return self.__dict__ == other.__dict__
-
-    # def __repr__(self):
-    #     return f"{self.__class__.__name__} {self.to_json_string()}"
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"

     @property
     def config(self) -> Dict[str, Any]:
...
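With __repr__ now live on ConfigMixin, printing any config-carrying object yields the class name followed by its JSON-serialized config. A hypothetical stand-in class (FakeConfig below is made up; only the __repr__ shape matches the commit) demonstrates the output format:

import json

class FakeConfig:
    """Made-up stand-in; only the __repr__ shape matches ConfigMixin."""

    def __init__(self, **kwargs):
        self._config = kwargs

    def to_json_string(self):
        return json.dumps(self._config, indent=2, sort_keys=True)

    def __repr__(self):
        # same shape as the newly enabled ConfigMixin.__repr__
        return f"{self.__class__.__name__} {self.to_json_string()}"

print(FakeConfig(timesteps=1000, beta_start=0.0001))
# FakeConfig {
#   "beta_start": 0.0001,
#   "timesteps": 1000
# }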
src/diffusers/pipelines/pipeline_glide.py

@@ -832,12 +832,12 @@ class GLIDE(DiffusionPipeline):
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.text_noise_scheduler.sample_noise(
-            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
-        )
+        image = torch.randn(
+            (batch_size, self.text_unet.in_channels, 64, 64), generator=generator
+        ).to(torch_device)

         # 2. Encode tokens
-        # an empty input is needed to guide the model away from (
+        # an empty input is needed to guide the model away from it
         inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
         input_ids = inputs["input_ids"].to(torch_device)
         attention_mask = inputs["attention_mask"].to(torch_device)

@@ -850,7 +850,7 @@ class GLIDE(DiffusionPipeline):
         mean, variance, log_variance, pred_xstart = self.p_mean_variance(
             text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
         )
-        noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+        noise = torch.randn(image.shape, generator=generator).to(torch_device)
         nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
         image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

@@ -873,8 +873,8 @@ class GLIDE(DiffusionPipeline):
                 self.upscale_unet.resolution,
                 self.upscale_unet.resolution,
             ),
             generator=generator,
-        )
-        image = image.to(torch_device) * upsample_temp
+        ).to(torch_device)
+        image = image * upsample_temp

         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)

@@ -896,7 +896,7 @@ class GLIDE(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                 variance = (
                     self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
                 )
...
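One context line worth understanding is the broadcast mask that suppresses noise at the last step: (t != 0) produces a per-sample boolean, and view(-1, 1, 1, 1) makes it broadcastable over channels and spatial dims. A self-contained sketch of the idiom (toy shapes, not GLIDE's):

import torch

# Standalone illustration of the broadcast-mask idiom from pipeline_glide.py:
# zero out the noise term for samples whose timestep is 0.
t = torch.tensor([0, 3])                  # per-sample timesteps, batch of 2
image = torch.randn(2, 3, 4, 4)           # (batch, channels, H, W)
noise = torch.randn_like(image)

# (t != 0) has shape (2,); view to (2, 1, 1, 1) so it broadcasts over C, H, W
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))

out = image + nonzero_mask * noise
assert torch.equal(out[0], image[0])      # sample with t == 0: unchanged
assert not torch.equal(out[1], image[1])  # sample with t != 0: noised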
src/diffusers/pipelines/pipeline_pndm.py

@@ -28,13 +28,11 @@ class PNDM(DiffusionPipeline):
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

     def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
-        # eta corresponds to η in paper and should be between [0, 1]
+        # For more information on the sampling method you can take a look at Algorithm 2 of
+        # the official paper: https://arxiv.org/pdf/2202.09778.pdf
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-        num_trained_timesteps = self.noise_scheduler.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
         self.unet.to(torch_device)

         # Sample gaussian noise to begin loop

@@ -44,91 +42,18 @@ class PNDM(DiffusionPipeline):
         )
         image = image.to(torch_device)

-        seq = list(inference_step_times)
-        seq_next = [-1] + list(seq[:-1])
-
-        model = self.unet
-
-        ets = []
-        prev_noises = []
-        step_idx = len(seq) - 1
-        while step_idx >= 0:
-            i = seq[step_idx]
-            j = seq_next[step_idx]
-
-            t = torch.ones(image.shape[0]) * i
-            t_next = torch.ones(image.shape[0]) * j
-
-            residual = model(image.to("cuda"), t.to("cuda"))
-            residual = residual.to("cpu")
-
-            t_list = [t, (t + t_next) / 2, t_next]
-            ets.append(residual)
-
-            if len(ets) <= 3:
-                image = image.to("cpu")
-                x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual)
-                e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
-                x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
-                e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
-                x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
-                e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
-                residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
-            else:
-                residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-
-            img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
-            image = img_next
-            step_idx = step_idx - 1
-
-        # if len(prev_noises) in [1, 2]:
-        #     t = (t + t_next) / 2
-        # elif len(prev_noises) == 3:
-        #     t = t_next / 2
-        # if len(prev_noises) == 0:
-        #     ets.append(residual)
-        #
-        # if len(ets) > 3:
-        #     residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-        #     step_idx = step_idx - 1
-        # elif len(ets) <= 3 and len(prev_noises) == 3:
-        #     residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual)
-        #     prev_noises = []
-        #     step_idx = step_idx - 1
-        # elif len(ets) <= 3 and len(prev_noises) < 3:
-        #     prev_noises.append(residual)
-        #     if len(prev_noises) < 2:
-        #         t_next = (t + t_next) / 2
-        #
-        # image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
+        warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
+        for t in tqdm.tqdm(range(len(warmup_time_steps))):
+            t_orig = warmup_time_steps[t]
+            residual = self.unet(image, t_orig)
+
+            image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
+
+        timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
+        for t in tqdm.tqdm(range(len(timesteps))):
+            t_orig = timesteps[t]
+            residual = self.unet(image, t_orig)
+
+            image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)

         return image

-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
-        # for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-        #     # 1. predict noise residual
-        #     with torch.no_grad():
-        #         residual = self.unet(image, inference_step_times[t])
-        #
-        #     # 2. predict previous mean of image x_t-1
-        #     pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
-        #
-        #     # 3. optionally sample variance
-        #     variance = 0
-        #     if eta > 0:
-        #         noise = torch.randn(image.shape, generator=generator).to(image.device)
-        #         variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
-        #
-        #     # 4. set current image to prev_image: x_t -> x_t-1
-        #     image = pred_prev_image + variance
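After this rewrite, __call__ is pure control flow: a Runge-Kutta warmup phase (step_prk) builds up the scheduler's residual history, then a linear-multistep phase (step_plms) runs with a single UNet call per step. A stub-based sketch of that two-phase structure (MockScheduler and mock_unet are invented placeholders, not diffusers APIs):

# Illustration only: the PRK -> PLMS control flow of the new __call__,
# with made-up stand-ins for the UNet and scheduler.
class MockScheduler:
    def get_warmup_time_steps(self, num_inference_steps):
        return list(range(12))            # 3 Runge-Kutta steps x 4 evaluations

    def get_time_steps(self, num_inference_steps):
        return list(range(num_inference_steps - 3))

    def step_prk(self, residual, image, t, num_inference_steps):
        return image - 0.01 * residual    # placeholder update rule

    def step_plms(self, residual, image, t, num_inference_steps):
        return image - 0.01 * residual    # placeholder update rule

def mock_unet(image, t):
    return 0.1 * image                    # placeholder noise prediction

def sample(image, scheduler, num_inference_steps=50):
    # phase 1: warmup fills the scheduler's residual history (self.ets)
    for t in range(len(scheduler.get_warmup_time_steps(num_inference_steps))):
        residual = mock_unet(image, t)
        image = scheduler.step_prk(residual, image, t, num_inference_steps)
    # phase 2: linear multistep, exactly one model evaluation per step
    for t in range(len(scheduler.get_time_steps(num_inference_steps))):
        residual = mock_unet(image, t)
        image = scheduler.step_plms(residual, image, t, num_inference_steps)
    return image

print(sample(1.0, MockScheduler()))       # a plain float works for the stub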
src/diffusers/schedulers/scheduling_pndm.py

@@ -55,22 +55,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.set_format(tensor_format=tensor_format)

-        # self.register_buffer("betas", betas.to(torch.float32))
-        # self.register_buffer("alphas", alphas.to(torch.float32))
-        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
-
-        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-
-        # TODO(PVP) - check how much of these is actually necessary!
-        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
-        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
-        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-        # if variance_type == "fixed_small":
-        #     log_variance = torch.log(variance.clamp(min=1e-20))
-        # elif variance_type == "fixed_large":
-        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
-        #
-        # self.register_buffer("log_variance", log_variance.to(torch.float32))
+        # For now we only support F-PNDM, i.e. the runge-kutta method
+        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+        # mainly at equations (12) and (13) and the Algorithm 2.
+        self.pndm_order = 4
+
+        # running values
+        self.cur_residual = 0
+        self.cur_image = None
+        self.ets = []
+        self.warmup_time_steps = {}
+        self.time_steps = {}

     def get_alpha(self, time_step):
         return self.alphas[time_step]

@@ -83,51 +78,64 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             return self.one
         return self.alphas_cumprod[time_step]

-    def step(self, img, t_start, t_end, model, ets):
-        # img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
-        # def gen_order_4(img, t, t_next, model, alphas_cump, ets):
-        t_next, t = t_start, t_end
-        noise_ = model(img.to("cuda"), t.to("cuda"))
-        noise_ = noise_.to("cpu")
-        t_list = [t, (t + t_next) / 2, t_next]
-        if len(ets) > 2:
-            ets.append(noise_)
-            noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-        else:
-            noise = self.runge_kutta(img, t_list, model, ets, noise_)
-
-        img_next = self.transfer(img.to("cpu"), t, t_next, noise)
-        return img_next, ets
-
-    def runge_kutta(self, x, t_list, model, ets, noise_):
-        model = model.to("cuda")
-        x = x.to("cpu")
-        e_1 = noise_
-        ets.append(e_1)
-        x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
-        e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
-        e_2 = e_2.to("cpu")
-        x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
-        e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
-        e_3 = e_3.to("cpu")
-        x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
-        e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
-        e_4 = e_4.to("cpu")
-        et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)
-
-        return et
+    def get_warmup_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.warmup_time_steps:
+            return self.warmup_time_steps[num_inference_steps]
+
+        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+
+        warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+            np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
+        )
+        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
+
+        return self.warmup_time_steps[num_inference_steps]
+
+    def get_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.time_steps:
+            return self.time_steps[num_inference_steps]
+
+        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
+
+        return self.time_steps[num_inference_steps]
+
+    def step_prk(self, residual, image, t, num_inference_steps):
+        # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
+        warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
+        t_prev = warmup_time_steps[t // 4 * 4]
+        t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
+
+        if t % 4 == 0:
+            self.cur_residual += 1 / 6 * residual
+            self.ets.append(residual)
+            self.cur_image = image
+        elif (t - 1) % 4 == 0:
+            self.cur_residual += 1 / 3 * residual
+        elif (t - 2) % 4 == 0:
+            self.cur_residual += 1 / 3 * residual
+        elif (t - 3) % 4 == 0:
+            residual = self.cur_residual + 1 / 6 * residual
+            self.cur_residual = 0
+
+        return self.transfer(self.cur_image, t_prev, t_next, residual)
+
+    def step_plms(self, residual, image, t, num_inference_steps):
+        timesteps = self.get_time_steps(num_inference_steps)
+        t_prev = timesteps[t]
+        t_next = timesteps[min(t + 1, len(timesteps) - 1)]
+
+        self.ets.append(residual)
+        residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+        return self.transfer(image, t_prev, t_next, residual)

     def transfer(self, x, t, t_next, et):
-        alphas_cump = self.alphas_cumprod
-        at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
-        at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
+        # TODO(Patrick): clean up to be compatible with numpy and give better names
+        alphas_cump = self.alphas_cumprod.to(x.device)
+        at = alphas_cump[t + 1].view(-1, 1, 1, 1)
+        at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)

         x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
...
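The warmup-timestep construction is the least obvious part of the new scheduler, so concrete numbers help. The following standalone snippet re-evaluates the two expressions from the code above for timesteps=1000, num_inference_steps=50, pndm_order=4 (outputs computed from the code as written):

import numpy as np

# Reproduces get_warmup_time_steps / get_time_steps standalone,
# assuming timesteps=1000, num_inference_steps=50, pndm_order=4.
timesteps, num_inference_steps, pndm_order = 1000, 50, 4

inference_step_times = list(range(0, timesteps, timesteps // num_inference_steps))
# -> [0, 20, 40, ..., 980]

warmup = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, timesteps // num_inference_steps // 2]), pndm_order
)
# -> [920, 930, 940, 950, 960, 970, 980, 990]: each of the last four steps
#    plus its half-step midpoint, as the Runge-Kutta warmup requires

warmup_time_steps = list(reversed(warmup[:-1].repeat(2)[1:-1]))
print(warmup_time_steps)
# [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]
# 12 entries = 3 Runge-Kutta steps x 4 model evaluations each

time_steps = list(reversed(inference_step_times[:-3]))
print(time_steps[:3], "...", time_steps[-1])  # [940, 920, 900] ... 0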
tests/test_modeling_utils.py

@@ -19,7 +19,7 @@ import unittest
 import torch

-from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
+from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.pipeline_bddm import DiffWave

@@ -229,3 +229,17 @@ class PipelineTesterMixin(unittest.TestCase):
         _ = BDDM.from_pretrained(tmpdirname)

         # check if the same works using the DifusionPipeline class
         _ = DiffusionPipeline.from_pretrained(tmpdirname)
+
+    @slow
+    def test_glide_text2img(self):
+        model_id = "fusing/glide-base"
+        glide = GLIDE.from_pretrained(model_id)
+
+        prompt = "a pencil sketch of a corgi"
+        generator = torch.manual_seed(0)
+        image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
+
+        image_slice = image[0, :3, :3, -1].cpu()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
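The new GLIDE test follows the standard seeded slice-regression pattern: fix the RNG, run the pipeline, and compare a tiny output slice against values recorded from a known-good run. A generic sketch of the idiom with toy tensors (fake_pipeline is a placeholder; the tolerance mirrors the 1e-2 used above):

import torch

def fake_pipeline(generator):
    """Placeholder for a real pipeline call; deterministic given the generator."""
    return torch.randn(1, 8, 8, 3, generator=generator)

# record a reference slice once from a known-good run ...
reference = fake_pipeline(torch.manual_seed(0))[0, :3, :3, -1].flatten()

# ... then any later run with the same seed must reproduce it
image = fake_pipeline(torch.manual_seed(0))
image_slice = image[0, :3, :3, -1].flatten()

assert image.shape == (1, 8, 8, 3)
assert (image_slice - reference).abs().max() < 1e-2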