renzhc / diffusers_dcu · Commits

Commit fe313730, authored Jun 06, 2022 by Patrick von Platen
Parent: 3a5c65d5

    improve

Showing 14 changed files with 691 additions and 760 deletions.
README.md                                    +4    −4
examples/sample_loop.py                      +142  −84
models/vision/ddpm/example.py                +5    −3
models/vision/ddpm/modeling_ddpm.py          +0    −1
models/vision/ddpm/run_ddpm.py               +2    −2
src/diffusers/__init__.py                    +2    −3
src/diffusers/configuration_utils.py         +11   −7
src/diffusers/modeling_utils.py              +1    −1
src/diffusers/models/unet.py                 +298  −309
src/diffusers/pipeline_utils.py              +4    −4
src/diffusers/samplers/gaussian.py           +0    −313
src/diffusers/schedulers/__init__.py         +1    −1
src/diffusers/schedulers/gaussian_ddpm.py    +98   −0
tests/test_modeling_utils.py                 +123  −28
README.md  (+4 −4)

@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th
 Example:

 ```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
 import torch

 # 1. Load model
@@ -40,7 +40,7 @@ time_step = torch.tensor([10])
 image = unet(dummy_noise, time_step)

 # 3. Load sampler
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # 4. Sample image from sampler passing the model
 image = sampler.sample(model, batch_size=1)
@@ -54,12 +54,12 @@ print(image)
 Example:

 ```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
 from modeling_ddpm import DDPM
 import tempfile

 unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # compose Diffusion Pipeline
 ddpm = DDPM(unet, sampler)
 ...
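Taken together, the renamed API reads end to end as below. This is a minimal sketch, assuming the `fusing/ddpm_dummy` checkpoint and the `from_pretrained`/`from_config` calls shown above; the 32×32 input shape is an assumption for the dummy model:

```python
import torch
from diffusers import UNetModel, GaussianDDPMScheduler

# model weights and scheduler config live under the same repo id
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# a single denoising step: the UNet predicts the noise residual
dummy_noise = torch.ones((1, 3, 32, 32))  # assumed dummy resolution
time_step = torch.tensor([10])
with torch.no_grad():
    residual = unet(dummy_noise, time_step)
print(residual.shape)  # same shape as the input noise
```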
examples/sample_loop.py  (+142 −84)

The script is rewritten from the `GaussianDiffusion` API to the new `GaussianDDPMScheduler` API.

Removed version (GaussianDiffusion-based):

```python
#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDiffusion
import torch
import torch.nn.functional as F
import numpy as np

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy")

# 2. Do one denoising step with model
batch_size, num_channels, height, width = 1, 3, 32, 32
dummy_noise = torch.ones((batch_size, num_channels, height, width))
TIME_STEPS = 10

# Helper
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()

# Schedule
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(TIME_STEPS)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)

torch.manual_seed(0)

# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t = dummy_noise
# 2: for t = T, ...., 1 do
for i in reversed(range(TIME_STEPS)):
    t = torch.tensor([i])
    # 3: z ~ N(0, 1)
    noise = noise_like(x_t.shape, "cpu")
    # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
    # ------------------------- MODEL ------------------------------------#
    pred_noise = unet(x_t, t)
    # pred epsilon_theta
    pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
    pred_x.clamp_(-1.0, 1.0)
    # pred mean
    posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
    # --------------------------------------------------------------------#
    # ------------------------- Variance Scheduler -----------------------#
    # pred variance
    posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
    b, *_, device = *x_t.shape, x_t.device
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
    posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
    # --------------------------------------------------------------------#
    x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)

    # FOR PATRICK TO VERIFY: make sure manual loop is equal to function
    # --------------------------------------------------------------------#
    x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
    assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
    # --------------------------------------------------------------------#
    x_t = x_t_1
```

New version (GaussianDDPMScheduler-based, including the commented-out scratch code the script carries):

```python
#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDDPMScheduler
import torch
import torch.nn.functional as F
import numpy as np
import PIL.Image
import tqdm

#torch_device = "cuda"
#
#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
#unet.to(torch_device)
#
#TIME_STEPS = 10
#
#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
#
#diffusion_config = {
#    "beta_start": 0.0001,
#    "beta_end": 0.02,
#    "num_diffusion_timesteps": TIME_STEPS,
#}
#
#batch_size, num_channels, height, width = 1, 3, 256, 256
#
#torch.manual_seed(0)
#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
#
#def noise_like(shape, device, repeat=False):
#    def repeat_noise():
#        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
#
#    def noise():
#        return torch.randn(shape, device=device)
#
#    return repeat_noise() if repeat else noise()
#
#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
#betas = torch.tensor(betas, device=torch_device)
#alphas = 1.0 - betas
#
#alphas_cumprod = torch.cumprod(alphas, axis=0)
#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
#
#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
#
#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
#
#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
#
#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
#coeff = 1 / torch.sqrt(alphas)

def real_fn():
    # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
    # 1: x_t ~ N(0,1)
    x_t = noise_image
    # 2: for t = T, ...., 1 do
    for i in reversed(range(TIME_STEPS)):
        t = torch.tensor([i]).to(torch_device)
        # 3: z ~ N(0, 1)
        noise = noise_like(x_t.shape, torch_device)
        # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
        # ------------------------- MODEL ------------------------------------#
        with torch.no_grad():
            pred_noise = unet(x_t, t)
        # pred epsilon_theta
        # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
        # pred_x.clamp_(-1.0, 1.0)
        # pred mean
        # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
        # --------------------------------------------------------------------#
        posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)
        # ------------------------- Variance Scheduler -----------------------#
        # pred variance
        posterior_log_variance = posterior_log_variance_clipped[t]
        b, *_, device = *x_t.shape, x_t.device
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
        posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
        # --------------------------------------------------------------------#
        x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
        x_t = x_t_1
        print(x_t.abs().sum())

def post_process_to_image(x_t):
    image = x_t.cpu().permute(0, 2, 3, 1)
    image = (image + 1.0) * 127.5
    image = image.numpy().astype(np.uint8)
    return PIL.Image.fromarray(image[0])

from pytorch_diffusion import Diffusion

#diffusion = Diffusion.from_pretrained("lsun_church")
#samples = diffusion.denoise(1)
#
#image = post_process_to_image(samples)
#image.save("check.png")
#import ipdb; ipdb.set_trace()

device = "cuda"

scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)
import ipdb; ipdb.set_trace()
model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)

torch.manual_seed(0)
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)

for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
    # define coefficients for time step t
    clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
    clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
    image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
    clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

    # predict noise residual
    with torch.no_grad():
        noise_residual = model(next_image, t)

    # compute prev image from noise
    pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
    pred_mean = torch.clamp(pred_mean, -1, 1)
    image = clip_coeff * pred_mean + image_coeff * next_image

    # sample variance
    variance = scheduler.sample_variance(t, image.shape, device=device)

    # sample previous image
    sampled_image = image + variance
    next_image = sampled_image

image = post_process_to_image(next_image)
image.save("example_new.png")
```
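Both loops implement the same algebra: the DDPM posterior step of Ho et al. 2020 (Algorithm 2 / Eq. 7 in https://arxiv.org/pdf/2006.11239.pdf). Writing ᾱ_t for the cumulative product of the α_t, the coefficients above are exactly:

```latex
% predicted clean image from the noise residual (clip_image_coeff, clip_noise_coeff)
\hat{x}_0 = \frac{x_t}{\sqrt{\bar{\alpha}_t}} - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\;\epsilon_\theta(x_t, t)

% posterior mean as a weighted average of \hat{x}_0 and x_t (clip_coeff, image_coeff)
\mu_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,\hat{x}_0
      + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,x_t

% previous sample, with \sigma_t^2 = \beta_t(1 - \bar{\alpha}_{t-1})/(1 - \bar{\alpha}_t)
% (the "fixed_small" variance) and z \sim \mathcal{N}(0, I), masked out at t = 0
x_{t-1} = \mu_t + \sigma_t z
```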
models/vision/ddpm/example.py  (+5 −3)

 #!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
-from modeling_ddpm import DDPM
 import tempfile

+from diffusers import GaussianDDPMScheduler, UNetModel
+from modeling_ddpm import DDPM
+
 unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # compose Diffusion Pipeline
 ddpm = DDPM(unet, sampler)
 ...
models/vision/ddpm/modeling_ddpm.py  (+0 −1)

@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline

 class DDPM(DiffusionPipeline):
     def __init__(self, unet, gaussian_sampler):
         super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
 ...

(The single deleted line is not visible in the captured hunk.)
models/vision/ddpm/run_ddpm.py  (+2 −2)

 #!/usr/bin/env python3
 import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel

 model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
-diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2
+diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2

 training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized from a range of -1 to +1
 loss = diffusion(training_images)
 ...
src/diffusers/__init__.py  (+2 −3)

@@ -4,8 +4,7 @@

 __version__ = "0.0.1"

-from .modeling_utils import PreTrainedModel
 from .models.unet import UNetModel
-from .samplers.gaussian import GaussianDiffusion
 from .pipeline_utils import DiffusionPipeline
+from .modeling_utils import PreTrainedModel
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/configuration_utils.py  (+11 −7)

@@ -17,10 +17,10 @@

 import copy
+import inspect
 import json
 import os
 import re
-import inspect
 from typing import Any, Dict, Tuple, Union

 from requests import HTTPError
 ...
@@ -186,6 +186,11 @@ class Config:

         expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
         expected_keys.remove("self")
+        for key in expected_keys:
+            if key in kwargs:
+                # overwrite key
+                config_dict[key] = kwargs.pop(key)
+
         passed_keys = set(config_dict.keys())

         unused_kwargs = kwargs
 ...
@@ -194,17 +199,16 @@ class Config:

         if len(expected_keys - passed_keys) > 0:
             logger.warn(
-                f"{expected_keys - passed_keys} was not found in config. "
-                f"Values will be initialized to default values."
+                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )

         return config_dict, unused_kwargs

     @classmethod
     def from_config(
-        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-        config_dict, unused_kwargs = cls.get_config_dict(
-            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+    ):
+        config_dict, unused_kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+        )

         model = cls(**config_dict)
 ...
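The new override loop means a keyword argument that names an `__init__` parameter now replaces the stored config value before instantiation, which is what the updated example script and tests rely on. A minimal sketch, assuming a scheduler saved with `timesteps=1000`:

```python
from diffusers import GaussianDDPMScheduler

# the stored scheduler config says timesteps=1000; the matching kwarg is
# popped into config_dict and wins over the stored value
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
assert len(scheduler) == 10
```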
src/diffusers/modeling_utils.py  (+1 −1)

@@ -24,6 +24,7 @@ from requests import HTTPError

 # CHANGE to diffusers.utils
 from transformers.utils import (
+    CONFIG_NAME,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     EntryNotFoundError,
     RepositoryNotFoundError,
 ...
@@ -33,7 +34,6 @@ from transformers.utils import (
     is_offline_mode,
     is_remote_url,
     logging,
-    CONFIG_NAME,
 )
 ...
src/diffusers/models/unet.py  (+298 −309)

@@ -17,376 +17,337 @@

 import copy
 import math
-from functools import partial
-from inspect import isfunction
 from pathlib import Path

 import torch
-from torch import einsum, nn
+from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.optim import Adam
 from torch.utils import data
-from einops import rearrange
-from torchvision import transforms, utils
+from torchvision import utils, transforms
+from PIL import Image
 from tqdm import tqdm

 from ..configuration_utils import Config
 from ..modeling_utils import PreTrainedModel
-from PIL import Image

The lucidrains helper layer (exists, default, cycle, num_to_groups, the normalize/unnormalize helpers, EMA, Residual, SinusoidalPosEmb, the functional Upsample/Downsample, LayerNorm, PreNorm, Block, the scale-shift ResnetBlock, LinearAttention and Attention) and the old UNetModel built on dim/dim_mults/init_dim/out_dim/channels/with_time_emb/resnet_block_groups/learned_variance are removed; EMA, cycle and num_to_groups reappear next to the Trainer class (hunk further below). In their place the file now carries the DDPM-style implementation, configured by ch/out_ch/ch_mult/num_res_blocks/attn_resolutions/dropout/resamp_with_conv/in_channels/resolution (condensed reconstruction; a file-level note references https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py for the retained training utilities):

```python
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)

def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_

class UNetModel(PreTrainedModel, Config):
    def __init__(
        self,
        ch=128,
        out_ch=3,
        ch_mult=(1, 1, 2, 2, 4, 4),
        num_res_blocks=2,
        attn_resolutions=(16,),
        dropout=0.0,
        resamp_with_conv=True,
        in_channels=3,
        resolution=256,
    ):
        super().__init__()
        self.register(
            ch=ch,
            out_ch=out_ch,
            ch_mult=ch_mult,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=in_channels,
            resolution=resolution,
        )
        ch_mult = tuple(ch_mult)
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList(
            [
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ]
        )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t):
        assert x.shape[2] == x.shape[3] == self.resolution

        if not torch.is_tensor(t):
            t = torch.tensor([t], dtype=torch.long, device=x.device)

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
```

The Dataset class is only reformatted (quote style and the transforms.Compose call), with no behavioral change. The trainer hunk then re-adds the helpers removed from the top of the file:

@@ -398,10 +359,38 @@ class Dataset(data.Dataset):

 # trainer class
+class EMA():
+    def __init__(self, beta):
+        super().__init__()
+        self.beta = beta
+
+    def update_model_average(self, ma_model, current_model):
+        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+            old_weight, up_weight = ma_params.data, current_params.data
+            ma_params.data = self.update_average(old_weight, up_weight)
+
+    def update_average(self, old, new):
+        if old is None:
+            return new
+        return old * self.beta + (1 - self.beta) * new
+
+
+def cycle(dl):
+    while True:
+        for data_dl in dl:
+            yield data_dl
+
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+
 class Trainer(object):
     def __init__(
         self,
         diffusion_model,
 ...
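A short usage sketch of the rewritten model, using the small configuration from `tests/test_modeling_utils.py`; the output shape follows from `out_ch=3` and the `resolution` assert in `forward`:

```python
import torch
from diffusers import UNetModel

# DDPM-style configuration: base channel count plus per-resolution multipliers
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)

x = torch.randn(1, 3, 32, 32)  # spatial size must equal model.resolution
t = torch.tensor([10])         # integer diffusion timestep
with torch.no_grad():
    eps = model(x, t)          # predicted noise residual
print(eps.shape)               # torch.Size([1, 3, 32, 32])
```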
src/diffusers/pipeline_utils.py  (+4 −4)

@@ -14,15 +14,15 @@

 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os
 from typing import Optional, Union

-import importlib
-
-from .configuration_utils import Config
-
 # CHANGE to diffusers.utils
 from transformers.utils import logging

+from .configuration_utils import Config
+

 INDEX_FILE = "diffusion_model.pt"
 ...
@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)

 LOADABLE_CLASSES = {
     "diffusers": {
         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
-        "GaussianDiffusion": ["save_config", "from_config"],
+        "GaussianDDPMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
 ...
src/diffusers/samplers/gaussian.py  (deleted 100644 → 0, −313)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn

from inspect import isfunction
from tqdm import tqdm

from ..configuration_utils import Config


SAMPLING_CONFIG_NAME = "sampler_config.json"


def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# small helper modules
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

# gaussian diffusion trainer class
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()

def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

class GaussianDiffusion(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        image_size,
        channels=3,
        timesteps=1000,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="cosine",
    ):
        super().__init__()
        self.register(
            image_size=image_size,
            channels=channels,
            timesteps=timesteps,
            loss_type=loss_type,
            objective=objective,
            beta_schedule=beta_schedule,
        )
        self.channels = channels
        self.image_size = image_size
        self.objective = objective

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # helper function to register buffer from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register_buffer(
            "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised: bool):
        model_output = model(x, t)

        if self.objective == "pred_noise":
            x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
        elif self.objective == "pred_x0":
            x_start = model_output
        else:
            raise ValueError(f"unknown objective {self.objective}")

        if clip_denoised:
            x_start.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
        if noise is None:
            noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return result

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, model, batch_size=16):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, model, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @property
    def loss_fn(self):
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")

    def p_losses(self, model, x_start, t, noise=None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = model(x, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_start
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target)
        return loss

    def forward(self, model, img, *args, **kwargs):
        b, _, h, w, device, img_size, = (
            *img.shape,
            img.device,
            self.image_size,
        )
        assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(model, img, t, *args, **kwargs)
src/diffusers/samplers/__init__.py → src/diffusers/schedulers/__init__.py  (+1 −1)

@@ -16,4 +16,4 @@

 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .gaussian import GaussianDiffusion
+from .gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/schedulers/gaussian_ddpm.py  (new file, 0 → 100644, +98)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from ..configuration_utils import Config


SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


class GaussianDDPMScheduler(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1)

        noise = self.sample_noise(shape, device=device, generator=generator)

        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise

        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
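As a quick sanity check on the buffers the scheduler registers, a sketch assuming the defaults above:

```python
import torch
from diffusers import GaussianDDPMScheduler

scheduler = GaussianDDPMScheduler()  # timesteps=1000, linear betas
assert len(scheduler) == 1000

# alphas_cumprod is the running product of alphas = 1 - betas
t = 10
assert torch.allclose(scheduler.get_alpha_prod(t), torch.prod(scheduler.alphas[: t + 1]))

# the lookup below index 0 is defined as 1.0, so the t == 0 posterior coefficients stay finite
assert scheduler.get_alpha_prod(-1) == 1.0
```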
tests/test_modeling_utils.py  (+123 −28)

@@ -16,13 +16,45 @@

 import random
 import tempfile
 import unittest
+import os
+from distutils.util import strtobool

 import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel


 global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def parse_flag_from_env(key, default=False):
+    try:
+        value = os.environ[key]
+    except KeyError:
+        # KEY isn't set, default to `default`.
+        _value = default
+    else:
+        # KEY is set, convert it to True or False.
+        try:
+            _value = strtobool(value)
+        except ValueError:
+            # More values are supported, but let's keep the message simple.
+            raise ValueError(f"If set, {key} must be yes or no.")
+    return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+
+
+def slow(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+    """
+    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


 def floats_tensor(shape, scale=1.0, rng=None, name=None):
 ...
@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):

         return (noise, time_step)

     def test_from_pretrained_save_pretrained(self):
-        model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
+        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)

         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
 ...
@@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase):

 class SamplerTesterMixin(unittest.TestCase):
-    @property
-    def dummy_model(self):
-        return UNetModel.from_pretrained("fusing/ddpm_dummy")
-
-    def test_from_pretrained_save_pretrained(self):
-        sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            sampler.save_config(tmpdirname)
-            new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)
-
-        model = self.dummy_model
-
-        torch.manual_seed(0)
-        sampled_out = sampler.sample(model, batch_size=1)
-
-        torch.manual_seed(0)
-        sampled_out_new = new_sampler.sample(model, batch_size=1)
-
-        assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"
-
-    def test_from_pretrained_hub(self):
-        sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
-
-        model = self.dummy_model
-
-        torch.manual_seed(0)
-        sampled_out = sampler.sample(model, batch_size=1)
-
-        assert sampled_out is not None, "Make sure output is not None"
+    @slow
+    def test_sample(self):
+        generator = torch.Generator()
+        generator = generator.manual_seed(6694729458485568)
+
+        # 1. Load models
+        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
+        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
+
+        # 2. Sample gaussian noise
+        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
+
+        # 3. Denoise
+        for t in reversed(range(len(scheduler))):
+            # i) define coefficients for time step t
+            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
+            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
+            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+
+            # ii) predict noise residual
+            with torch.no_grad():
+                noise_residual = model(image, t)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clip_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+
+            # v) sample  x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
+        # Note: The better test is to simply check with the following lines of code that the image is sensible
+        # import PIL
+        # import numpy as np
+        # image_processed = image.cpu().permute(0, 2, 3, 1)
+        # image_processed = (image_processed + 1.0) * 127.5
+        # image_processed = image_processed.numpy().astype(np.uint8)
+        # image_pil = PIL.Image.fromarray(image_processed[0])
+        # image_pil.save("test.png")
+
+        assert image.shape == (1, 3, 256, 256)
+        image_slice = image[0, -1, -3:, -3:].cpu()
+        assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3
+
+    def test_sample_fast(self):
+        # 1. Load models
+        generator = torch.Generator()
+        generator = generator.manual_seed(6694729458485568)
+
+        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
+        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
+
+        # 2. Sample gaussian noise
+        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
+
+        # 3. Denoise
+        for t in reversed(range(len(scheduler))):
+            # i) define coefficients for time step t
+            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
+            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
+            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+
+            # ii) predict noise residual
+            with torch.no_grad():
+                noise_residual = model(image, t)
+
+            # iii) compute predicted image from residual
+            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clip_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+
+            # v) sample  x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
+        assert image.shape == (1, 3, 256, 256)
+        image_slice = image[0, -1, -3:, -3:].cpu()
+        assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
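A note on running these tests: `test_sample` walks all 1000 default timesteps of the LSUN-church model, so it is gated behind the new `slow` decorator and only runs when the `RUN_SLOW` environment variable is set to a truthy value (as parsed by `strtobool`); `test_sample_fast` covers the same denoising loop in 10 steps on every run.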