renzhc / diffusers_dcu · Commits

Commit b02d0d6b, authored Jun 09, 2022 by Patrick von Platen

merge

parents 49257b4a 02cdd683
Changes: 25 files in this commit; this page shows 5 changed files with 262 additions and 24 deletions (+262 −24).
src/diffusers/schedulers/classifier_free_guidance.py   +97 −0
src/diffusers/schedulers/gaussian_ddpm.py              +1 −24
src/diffusers/schedulers/glide_ddim.py                 +96 −0
src/diffusers/schedulers/schedulers_utils.py           +38 −0
tests/test_modeling_utils.py                           +30 −0
src/diffusers/schedulers/classifier_free_guidance.py (new file, 0 → 100644)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch
from torch import nn

from ..configuration_utils import ConfigMixin


SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)


class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_schedule=beta_schedule,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
            self.betas = betas_for_alpha_bar(
                timesteps,
                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
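The "squaredcos_cap_v2" branch above is the GLIDE cosine schedule. A minimal standalone sketch of how those betas behave, duplicating the helper above (plain math/numpy, no diffusers import needed) — the betas start tiny and are capped at max_beta toward the end of the chain:

import math
import numpy as np

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # same discretization as above: beta_i = 1 - alpha_bar(t2) / alpha_bar(t1)
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)

# GLIDE cosine schedule: alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2
betas = betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
print(betas[0], betas[-1])  # tiny first beta; last beta hits the 0.999 cap
assert betas.min() > 0 and betas.max() <= 0.999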
src/diffusers/schedulers/gaussian_ddpm.py
@@ -16,35 +16,12 @@ import math
 from torch import nn

 from ..configuration_utils import ConfigMixin
+from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar


 SAMPLING_CONFIG_NAME = "scheduler_config.json"


-def linear_beta_schedule(timesteps, beta_start, beta_end):
-    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function,
-    which defines the cumulative product of (1-beta) over time from t = [0,1].
-
-    :param num_diffusion_timesteps: the number of betas to produce.
-    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-                      produces the cumulative product of (1-beta) up to that
-                      part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
-                     prevent singularities.
-    """
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return torch.tensor(betas, dtype=torch.float64)
-
-
 class GaussianDDPMScheduler(nn.Module, ConfigMixin):

     config_name = SAMPLING_CONFIG_NAME
...
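After this hunk, the schedule helpers live only in the shared schedulers_utils module. A hypothetical one-liner to confirm the deduplication resolved correctly (assuming the package is installed from this tree; the import path mirrors the file layout above):

from diffusers.schedulers.schedulers_utils import linear_beta_schedule, betas_for_alpha_bar

print(linear_beta_schedule.__module__)  # expected: diffusers.schedulers.schedulers_utils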
src/diffusers/schedulers/glide_ddim.py (new file, 0 → 100644)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
from torch import nn

from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar


SAMPLING_CONFIG_NAME = "scheduler_config.json"


class GlideDDIMScheduler(nn.Module, ConfigMixin):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_schedule=beta_schedule,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            # Linear schedule from Ho et al, extended to work for any number of
            # diffusion steps.
            scale = 1000 / self.num_timesteps
            beta_start = scale * 0.0001
            beta_end = scale * 0.02
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :]

        noise = self.sample_noise(shape, device=device, generator=generator)

        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise

        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
\ No newline at end of file
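A standalone sketch of the "fixed_large" variance computed in __init__ above, assuming the default timesteps=1000 (so scale == 1); it is runnable with torch alone. The first posterior variance is exactly 0 (its log would be -inf), so index 0 is spliced from variance[1:2] while all later steps use beta_t directly:

import torch

betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

# posterior variance of q(x_{t-1} | x_t, x_0), zero at t=0
variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# "fixed_large": use beta_t, with the first posterior variance spliced in at t=0
log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
assert log_variance.shape == (1000,) and torch.isfinite(log_variance).all()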
src/diffusers/schedulers/schedulers_utils.py (new file, 0 → 100644)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float64)
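A minimal usage sketch of the two shared helpers (the import path is an assumption based on the file layout in this commit; pasting the definitions above works equally well). Note that this module's betas_for_alpha_bar returns a torch tensor, whereas classifier_free_guidance.py keeps its own numpy copy:

import math
from diffusers.schedulers.schedulers_utils import linear_beta_schedule, betas_for_alpha_bar

# 1000-step linear schedule, as GlideDDIMScheduler uses (scale == 1000/1000 == 1)
linear = linear_beta_schedule(1000, beta_start=0.0001, beta_end=0.02)

# 1000-step cosine schedule, the same alpha_bar used for "squaredcos_cap_v2"
cosine = betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)

print(linear.shape, linear.dtype)           # torch.Size([1000]) torch.float64
print(float(cosine[0]), float(cosine[-1]))  # tiny first beta, last capped near 0.999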
tests/test_modeling_utils.py
@@ -26,6 +26,7 @@ from diffusers import GaussianDDPMScheduler, UNetModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from models.vision.ddpm.modeling_ddpm import DDPM
+from models.vision.ddim.modeling_ddim import DDIM


 global_rng = random.Random()
...
@@ -245,6 +246,7 @@ class SamplerTesterMixin(unittest.TestCase):
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
         model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
...
@@ -281,3 +283,31 @@ class PipelineTesterMixin(unittest.TestCase):
         new_image = ddpm_from_hub(generator=generator)

         assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+
+    @slow
+    def test_ddpm_cifar10(self):
+        generator = torch.manual_seed(0)
+        model_id = "fusing/ddpm-cifar10"
+
+        ddpm = DDPM.from_pretrained(model_id)
+        image = ddpm(generator=generator)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 32, 32)
+        expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
+    @slow
+    def test_ddim_cifar10(self):
+        generator = torch.manual_seed(0)
+        model_id = "fusing/ddpm-cifar10"
+
+        ddim = DDIM.from_pretrained(model_id)
+        image = ddim(generator=generator, eta=0.0)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 32, 32)
+        expected_slice = torch.tensor([-0.7688, -0.7690, -0.7597, -0.7660, -0.7713, -0.7531, -0.7009, -0.7098, -0.7350])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
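Both new tests use the same slice-comparison pattern: sample an image, take the bottom-right 3×3 corner of the last channel, and compare it against hard-coded reference values. A hypothetical standalone illustration of that pattern (the random image stands in for a model output; the reference values in the real tests are recorded from a known-good run):

import torch

image = torch.rand(1, 3, 32, 32, generator=torch.manual_seed(0))
image_slice = image[0, -1, -3:, -3:]            # 3x3 corner of the last channel
expected_slice = image_slice.flatten().clone()  # stand-in for recorded reference values
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2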