renzhc / diffusers_dcu · Commits

Commit fe313730, authored Jun 06, 2022 by Patrick von Platen

improve

parent 3a5c65d5

Showing 14 changed files with 691 additions and 760 deletions (+691 −760)
README.md                                   +4    -4
examples/sample_loop.py                     +142  -84
models/vision/ddpm/example.py               +5    -3
models/vision/ddpm/modeling_ddpm.py         +0    -1
models/vision/ddpm/run_ddpm.py              +2    -2
src/diffusers/__init__.py                   +2    -3
src/diffusers/configuration_utils.py        +11   -7
src/diffusers/modeling_utils.py             +1    -1
src/diffusers/models/unet.py                +298  -309
src/diffusers/pipeline_utils.py             +4    -4
src/diffusers/samplers/gaussian.py          +0    -313
src/diffusers/schedulers/__init__.py        +1    -1
src/diffusers/schedulers/gaussian_ddpm.py   +98   -0
tests/test_modeling_utils.py                +123  -28
README.md

@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th

Example:

```diff
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
 import torch

 # 1. Load model
```

@@ -40,7 +40,7 @@ time_step = torch.tensor([10])

```diff
 image = unet(dummy_noise, time_step)

 # 3. Load sampler
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # 4. Sample image from sampler passing the model
 image = sampler.sample(model, batch_size=1)
```

@@ -54,12 +54,12 @@ print(image)

Example:

```diff
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
 from modeling_ddpm import DDPM
 import tempfile

 unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # compose Diffusion Pipeline
 ddpm = DDPM(unet, sampler)
```
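The save/load round trip the README describes can be sketched end to end. A minimal sketch, assuming the `fusing/ddpm_dummy` dummy checkpoint used above and the `save_pretrained`/`save_config` entry points this commit wires up in `pipeline_utils.py` and exercises in `tests/test_modeling_utils.py`:

```python
import tempfile

from diffusers import GaussianDDPMScheduler, UNetModel

# load model weights and scheduler config from the Hub
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# save both locally, then load them back
with tempfile.TemporaryDirectory() as tmpdirname:
    unet.save_pretrained(tmpdirname)   # model: save_pretrained / from_pretrained
    scheduler.save_config(tmpdirname)  # scheduler: save_config / from_config
    unet = UNetModel.from_pretrained(tmpdirname)
    scheduler = GaussianDDPMScheduler.from_config(tmpdirname)
```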
examples/sample_loop.py

```diff
 #!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
 import torch
 import torch.nn.functional as F

 unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
 diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy")

 import numpy as np
 import PIL.Image
 import tqdm

 #torch_device = "cuda"
 #
 #unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
 #unet.to(torch_device)
 #
 #TIME_STEPS = 10
 #
 #scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
 #
 #diffusion_config = {
 #    "beta_start": 0.0001,
 #    "beta_end": 0.02,
 #    "num_diffusion_timesteps": TIME_STEPS,
 #}
 #

 # 2. Do one denoising step with model
 batch_size, num_channels, height, width = 1, 3, 32, 32
 dummy_noise = torch.ones((batch_size, num_channels, height, width))
 TIME_STEPS = 10

 #batch_size, num_channels, height, width = 1, 3, 256, 256
 #
 #torch.manual_seed(0)
 #noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
 #
 #

 # Helper
 def extract(a, t, x_shape):
     b, *_ = t.shape
     out = a.gather(-1, t)
     return out.reshape(b, *((1,) * (len(x_shape) - 1)))

 #def noise_like(shape, device, repeat=False):
 #    def repeat_noise():
 #        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
 #
 #    def noise():
 #        return torch.randn(shape, device=device)
 #
 #    return repeat_noise() if repeat else noise()
 #
 #
 #betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
 #betas = torch.tensor(betas, device=torch_device)
 #alphas = 1.0 - betas
 #
 #alphas_cumprod = torch.cumprod(alphas, axis=0)
 #alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
 #
 #posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
 #posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
 #
 #posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
 #posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
 #
 #
 #sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
 #sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
 #
 #
 #noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
 #coeff = 1 / torch.sqrt(alphas)

 def real_fn():
     # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
     # 1: x_t ~ N(0,1)
     x_t = noise_image
     # 2: for t = T, ...., 1 do
     for i in reversed(range(TIME_STEPS)):
         t = torch.tensor([i]).to(torch_device)
         # 3: z ~ N(0, 1)
         noise = noise_like(x_t.shape, torch_device)
         # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alphabar_t) * eps_theta(x_t, t)) + sigma_t * z
         # ------------------------- MODEL ------------------------------------#
         with torch.no_grad():
             pred_noise = unet(x_t, t)
         # pred epsilon_theta
         # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
         # pred_x.clamp_(-1.0, 1.0)
         # pred mean
         # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
         # --------------------------------------------------------------------#
         posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)
         # ------------------------- Variance Scheduler -----------------------#
         # pred variance
         posterior_log_variance = posterior_log_variance_clipped[t]
         b, *_, device = *x_t.shape, x_t.device
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
         posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
         # --------------------------------------------------------------------#

         x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
         x_t = x_t_1
         print(x_t.abs().sum())

 def post_process_to_image(x_t):
     image = x_t.cpu().permute(0, 2, 3, 1)
     image = (image + 1.0) * 127.5
     image = image.numpy().astype(np.uint8)
     return PIL.Image.fromarray(image[0])

 from pytorch_diffusion import Diffusion

 #diffusion = Diffusion.from_pretrained("lsun_church")
 #samples = diffusion.denoise(1)
 #
 #image = post_process_to_image(samples)
 #image.save("check.png")
 #import ipdb; ipdb.set_trace()

 device = "cuda"

 scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)
 import ipdb; ipdb.set_trace()
 model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)

 def noise_like(shape, device, repeat=False):
     def repeat_noise():
         return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

     def noise():
         return torch.randn(shape, device=device)

     return repeat_noise() if repeat else noise()

 # Schedule
 def cosine_beta_schedule(timesteps, s=0.008):
     """
     cosine schedule
     as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
     """
     steps = timesteps + 1
     x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
     alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)

 torch.manual_seed(0)
 next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)

 for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
     # define coefficients for time step t
     clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
     clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
     image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
     clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

     betas = cosine_beta_schedule(TIME_STEPS)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, axis=0)
     alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

     # predict noise residual
     with torch.no_grad():
         noise_residual = model(next_image, t)

     posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
     posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

     # compute prev image from noise
     pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
     pred_mean = torch.clamp(pred_mean, -1, 1)
     image = clip_coeff * pred_mean + image_coeff * next_image

     posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
     posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))

     # sample variance
     variance = scheduler.sample_variance(t, image.shape, device=device)

     # sample previous image
     sampled_image = image + variance

     sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
     sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)

     next_image = sampled_image

 torch.manual_seed(0)
 # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
 # 1: x_t ~ N(0,1)
 x_t = dummy_noise
 # 2: for t = T, ...., 1 do
 for i in reversed(range(TIME_STEPS)):
     t = torch.tensor([i])
     # 3: z ~ N(0, 1)
     noise = noise_like(x_t.shape, "cpu")
     # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alphabar_t) * eps_theta(x_t, t)) + sigma_t * z
     # ------------------------- MODEL ------------------------------------#
     pred_noise = unet(x_t, t)
     # pred epsilon_theta
     pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
     pred_x.clamp_(-1.0, 1.0)
     # pred mean
     posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
     # --------------------------------------------------------------------#
     # ------------------------- Variance Scheduler -----------------------#
     # pred variance
     posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
     b, *_, device = *x_t.shape, x_t.device
     nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
     posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
     # --------------------------------------------------------------------#

     x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)

     # FOR PATRICK TO VERIFY: make sure manual loop is equal to function
     # --------------------------------------------------------------------#
     x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
     assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
     # --------------------------------------------------------------------#
     x_t = x_t_1

 image = post_process_to_image(next_image)
 image.save("example_new.png")
```
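For reference, the update both Algorithm-2 loops above implement (the step quoted in the `# 4:` comment) is, in the notation of the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf):

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) + \sigma_t z,
\qquad z \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
```

with the $\sigma_t z$ term masked out at $t = 0$, which is what the `nonzero_mask` factor in the code does.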
models/vision/ddpm/example.py

```diff
 #!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
-from modeling_ddpm import DDPM
-import tempfile
+from diffusers import GaussianDDPMScheduler, UNetModel
+from modeling_ddpm import DDPM

 unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # compose Diffusion Pipeline
 ddpm = DDPM(unet, sampler)
 ...
```
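The `DDPM` composed here is a `DiffusionPipeline`, so it should follow the per-component save/load rules that `LOADABLE_CLASSES` in `pipeline_utils.py` declares. A hypothetical round-trip sketch — the `save_pretrained`/`from_pretrained` methods on the pipeline itself are an assumption, not shown in this commit:

```python
import tempfile

# assumption: DiffusionPipeline exposes save_pretrained/from_pretrained and
# dispatches per component (from_pretrained for the unet, from_config for the
# sampler), as LOADABLE_CLASSES in pipeline_utils.py suggests
with tempfile.TemporaryDirectory() as tmpdirname:
    ddpm.save_pretrained(tmpdirname)
    ddpm = DDPM.from_pretrained(tmpdirname)
```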
models/vision/ddpm/modeling_ddpm.py

@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline

```python
class DDPM(DiffusionPipeline):
    def __init__(self, unet, gaussian_sampler):
        super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
...
```
models/vision/ddpm/run_ddpm.py

```diff
 #!/usr/bin/env python3
 import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel


 model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))

-diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2
+diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2

 training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized from a range of -1 to +1
 loss = diffusion(training_images)
 ...
```
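run_ddpm.py stops at a single loss computation. A minimal training-loop sketch around it, assuming the loss-returning forward of the old `GaussianDiffusion` class (the renamed `GaussianDDPMScheduler` does not define such a forward in this commit) and a plain Adam optimizer:

```python
import torch
from torch.optim import Adam

# `model` and `diffusion` as constructed in run_ddpm.py above
optimizer = Adam(model.parameters(), lr=1e-4)  # hypothetical learning rate

for step in range(100):  # hypothetical number of steps
    training_images = torch.randn(8, 3, 128, 128)  # stand-in batch in [-1, 1]
    loss = diffusion(training_images)  # forward returns the denoising loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```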
src/diffusers/__init__.py

@@ -4,8 +4,7 @@

```diff
 __version__ = "0.0.1"

-from .modeling_utils import PreTrainedModel
 from .models.unet import UNetModel
-from .samplers.gaussian import GaussianDiffusion
 from .pipeline_utils import DiffusionPipeline
+
+from .modeling_utils import PreTrainedModel
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
```
src/diffusers/configuration_utils.py

@@ -17,10 +17,10 @@

```diff
 import copy
+import inspect
 import json
 import os
 import re
-import inspect
 from typing import Any, Dict, Tuple, Union

 from requests import HTTPError
```

@@ -186,6 +186,11 @@ class Config:

```diff
         expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
         expected_keys.remove("self")
+
+        for key in expected_keys:
+            if key in kwargs:
+                # overwrite key
+                config_dict[key] = kwargs.pop(key)
+
         passed_keys = set(config_dict.keys())

         unused_kwargs = kwargs
```

@@ -194,17 +199,16 @@ class Config:

```diff
         if len(expected_keys - passed_keys) > 0:
             logger.warn(
-                f"{expected_keys - passed_keys} was not found in config. "
-                f"Values will be initialized to default values."
+                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )

         return config_dict, unused_kwargs

     @classmethod
     def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
         config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

         model = cls(**config_dict)
 ...
```
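The new kwargs loop above lets callers override individual config values at load time: any keyword that matches an `__init__` parameter replaces the stored entry before the class is instantiated, and the rest come back as `unused_kwargs`. A short sketch of the intended effect, assuming the checkpoint's config stores `timesteps=1000`:

```python
from diffusers import GaussianDDPMScheduler

# `timesteps` matches an __init__ parameter, so it overwrites the stored value
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
assert len(scheduler) == 10  # __len__ returns num_timesteps
```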
src/diffusers/modeling_utils.py

@@ -24,6 +24,7 @@ from requests import HTTPError

```diff
 # CHANGE to diffusers.utils
 from transformers.utils import (
+    CONFIG_NAME,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     EntryNotFoundError,
     RepositoryNotFoundError,
```

@@ -33,7 +34,6 @@ from transformers.utils import (

```diff
     is_offline_mode,
     is_remote_url,
     logging,
-    CONFIG_NAME,
 )
```
src/diffusers/models/unet.py

@@ -17,376 +17,337 @@

```diff
 import copy
 import math
 from functools import partial
 from inspect import isfunction
 from pathlib import Path

 import torch
-from torch import einsum, nn
+from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.optim import Adam
 from torch.utils import data
 from einops import rearrange
-from torchvision import utils, transforms
-from PIL import Image
+from torchvision import transforms, utils
 from tqdm import tqdm

 from ..configuration_utils import Config
 from ..modeling_utils import PreTrainedModel
+from PIL import Image


 # NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py


 def exists(x):
     return x is not None


 def default(val, d):
     if exists(val):
         return val
     return d() if isfunction(d) else d


-def cycle(dl):
-    while True:
-        for data_dl in dl:
-            yield data_dl
-
-
-def num_to_groups(num, divisor):
-    groups = num // divisor
-    remainder = num % divisor
-    arr = [divisor] * groups
-    if remainder > 0:
-        arr.append(remainder)
-    return arr
-
-
 def normalize_to_neg_one_to_one(img):
     return img * 2 - 1


 def unnormalize_to_zero_to_one(t):
     return (t + 1) * 0.5


 # small helper modules
-class EMA:
-    def __init__(self, beta):
-        super().__init__()
-        self.beta = beta
-
-    def update_model_average(self, ma_model, current_model):
-        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
-            old_weight, up_weight = ma_params.data, current_params.data
-            ma_params.data = self.update_average(old_weight, up_weight)
-
-    def update_average(self, old, new):
-        if old is None:
-            return new
-        return old * self.beta + (1 - self.beta) * new
-
-
-class Residual(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *args, **kwargs):
-        return self.fn(x, *args, **kwargs) + x
-
-
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-def Upsample(dim):
-    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
-
-
-def Downsample(dim):
-    return nn.Conv2d(dim, dim, 4, 2, 1)
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, dim, eps=1e-5):
-        super().__init__()
-        self.eps = eps
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
-
-    def forward(self, x):
-        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
-        mean = torch.mean(x, dim=1, keepdim=True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
-
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = LayerNorm(dim)
-
-    def forward(self, x):
-        x = self.norm(x)
-        return self.fn(x)
-
-
-# building block modules
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super().__init__()
-        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
-        self.norm = nn.GroupNorm(groups, dim_out)
-        self.act = nn.SiLU()
-
-    def forward(self, x, scale_shift=None):
-        x = self.proj(x)
-        x = self.norm(x)
-
-        if exists(scale_shift):
-            scale, shift = scale_shift
-            x = x * (scale + 1) + shift
-
-        x = self.act(x)
-        return x
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models:
+    From Fairseq.
+    Build sinusoidal embeddings.
+    This matches the implementation in tensor2tensor, but differs slightly
+    from the description in Section 3.5 of "Attention Is All You Need".
+    """
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+def nonlinearity(x):
+    # swish
+    return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x


 class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
-        super().__init__()
-        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb=None):
-        scale_shift = None
-        if exists(self.mlp) and exists(time_emb):
-            time_emb = self.mlp(time_emb)
-            time_emb = rearrange(time_emb, "b c -> b c 1 1")
-            scale_shift = time_emb.chunk(2, dim=1)
-
-        h = self.block1(x, scale_shift=scale_shift)
-        h = self.block2(h)
-        return h + self.res_conv(x)
-
-
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim))
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)
-
-        q = q.softmax(dim=-2)
-        k = k.softmax(dim=-1)
-
-        q = q * self.scale
-        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
-
-        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
-        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
-        return self.to_out(out)
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)
-
-        q = q * self.scale
-        sim = einsum("b h d i, b h d j -> b h i j", q, k)
-        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
-        attn = sim.softmax(dim=-1)
-
-        out = einsum("b h i j, b h d j -> b h i d", attn, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-        return self.to_out(out)
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)  # b,hw,c
+        k = k.reshape(b, c, h * w)  # b,c,hw
+        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w)
+
+        h_ = self.proj_out(h_)
+
+        return x + h_


 class UNetModel(PreTrainedModel, Config):
     def __init__(
         self,
         dim=64,
         dim_mults=(1, 2, 4, 8),
         init_dim=None,
         out_dim=None,
         channels=3,
         with_time_emb=True,
         resnet_block_groups=8,
         learned_variance=False,
+        ch=128,
+        out_ch=3,
+        ch_mult=(1, 1, 2, 2, 4, 4),
+        num_res_blocks=2,
+        attn_resolutions=(16,),
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels=3,
+        resolution=256,
     ):
         super().__init__()
         self.register(
             dim=dim,
             dim_mults=dim_mults,
             init_dim=init_dim,
             out_dim=out_dim,
             channels=channels,
             with_time_emb=with_time_emb,
             resnet_block_groups=resnet_block_groups,
             learned_variance=learned_variance,
+            ch=ch,
+            out_ch=out_ch,
+            ch_mult=ch_mult,
+            num_res_blocks=num_res_blocks,
+            attn_resolutions=attn_resolutions,
+            dropout=dropout,
+            resamp_with_conv=resamp_with_conv,
+            in_channels=in_channels,
+            resolution=resolution,
         )
+
+        ch_mult = tuple(ch_mult)
+        self.ch = ch
+        self.temb_ch = self.ch * 4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # timestep embedding
+        self.temb = nn.Module()
+        self.temb.dense = nn.ModuleList(
+            [
+                torch.nn.Linear(self.ch, self.temb_ch),
+                torch.nn.Linear(self.temb_ch, self.temb_ch),
+            ]
+        )

-        init_dim = None
-        out_dim = None
-        channels = 3
-        with_time_emb = True
-        resnet_block_groups = 8
-        learned_variance = False
-
-        # determine dimensions
-        dim_mults = dim_mults
-        dim = dim
-        self.channels = channels
-
-        init_dim = default(init_dim, dim // 3 * 2)
-        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
-
-        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
-        in_out = list(zip(dims[:-1], dims[1:]))
-
-        block_klass = partial(ResnetBlock, groups=resnet_block_groups)
-
-        # time embeddings
-        if with_time_emb:
-            time_dim = dim * 4
-            self.time_mlp = nn.Sequential(
-                SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim)
-            )
-        else:
-            time_dim = None
-            self.time_mlp = None
-
-        # layers
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
-        num_resolutions = len(in_out)
-
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.downs.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
-                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
-                        Downsample(dim_out) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        mid_dim = dims[-1]
-        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
-        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
-
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            is_last = ind >= (num_resolutions - 1)
-
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
-                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
-                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
-                        Upsample(dim_in) if not is_last else nn.Identity(),
-                    ]
-                )
-            )
-
-        default_out_dim = channels * (1 if not learned_variance else 2)
-        self.out_dim = default(out_dim, default_out_dim)
-
-        self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, self.out_dim, 1))
-
-    def forward(self, x, time):
-        x = self.init_conv(x)
-        t = self.time_mlp(time) if exists(self.time_mlp) else None
-
-        h = []
-        for block1, block2, attn, downsample in self.downs:
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            h.append(x)
-            x = downsample(x)
-
-        x = self.mid_block1(x, t)
-        x = self.mid_attn(x)
-        x = self.mid_block2(x, t)
-
-        for block1, block2, attn, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim=1)
-            x = block1(x, t)
-            x = block2(x, t)
-            x = attn(x)
-            x = upsample(x)
-
-        return self.final_conv(x)
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,) + ch_mult
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+        )
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+        )
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            skip_in = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch * in_ch_mult[i_level]
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in + skip_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x, t):
+        assert x.shape[2] == x.shape[3] == self.resolution
+
+        if not torch.is_tensor(t):
+            t = torch.tensor([t], dtype=torch.long, device=x.device)
+
+        # timestep embedding
+        temb = get_timestep_embedding(t, self.ch)
+        temb = self.temb.dense[0](temb)
+        temb = nonlinearity(temb)
+        temb = self.temb.dense[1](temb)
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h


 # dataset classes
 class Dataset(data.Dataset):
-    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
+    def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
         super().__init__()
         self.folder = folder
         self.image_size = image_size
-        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
+        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

-        self.transform = transforms.Compose(
-            [
-                transforms.Resize(image_size),
-                transforms.RandomHorizontalFlip(),
-                transforms.CenterCrop(image_size),
-                transforms.ToTensor(),
-            ]
-        )
+        self.transform = transforms.Compose([
+            transforms.Resize(image_size),
+            transforms.RandomHorizontalFlip(),
+            transforms.CenterCrop(image_size),
+            transforms.ToTensor()
+        ])

     def __len__(self):
         return len(self.paths)
 ...
```

@@ -398,10 +359,38 @@ class Dataset(data.Dataset):

```diff
 # trainer class
+class EMA():
+    def __init__(self, beta):
+        super().__init__()
+        self.beta = beta
+
+    def update_model_average(self, ma_model, current_model):
+        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+            old_weight, up_weight = ma_params.data, current_params.data
+            ma_params.data = self.update_average(old_weight, up_weight)
+
+    def update_average(self, old, new):
+        if old is None:
+            return new
+        return old * self.beta + (1 - self.beta) * new
+
+
+def cycle(dl):
+    while True:
+        for data_dl in dl:
+            yield data_dl
+
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+
 class Trainer(object):
     def __init__(
         self,
         diffusion_model,
 ...
```
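get_timestep_embedding above is the standard sinusoidal embedding: half of the output channels are sine features, half cosine, with an extra zero-pad column when `embedding_dim` is odd. A quick shape-check sketch:

```python
import torch

t = torch.tensor([0, 10, 999])        # a batch of three timesteps
emb = get_timestep_embedding(t, 128)  # 128 = the default `ch` of UNetModel
assert emb.shape == (3, 128)          # 64 sin features followed by 64 cos features
```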
src/diffusers/pipeline_utils.py

@@ -14,15 +14,15 @@

```diff
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os
 from typing import Optional, Union

-import importlib
-
-from .configuration_utils import Config
-
 # CHANGE to diffusers.utils
 from transformers.utils import logging

+from .configuration_utils import Config
+

 INDEX_FILE = "diffusion_model.pt"
```

@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)

```diff
 LOADABLE_CLASSES = {
     "diffusers": {
         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
-        "GaussianDiffusion": ["save_config", "from_config"],
+        "GaussianDDPMScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
 ...
```
src/diffusers/samplers/gaussian.py — deleted (100644 → 0)

```python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn

from inspect import isfunction
from tqdm import tqdm

from ..configuration_utils import Config


SAMPLING_CONFIG_NAME = "sampler_config.json"


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# small helper modules
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


# gaussian diffusion trainer class
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        image_size,
        channels=3,
        timesteps=1000,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="cosine",
    ):
        super().__init__()
        self.register(
            image_size=image_size,
            channels=channels,
            timesteps=timesteps,
            loss_type=loss_type,
            objective=objective,
            beta_schedule=beta_schedule,
        )
        self.channels = channels
        self.image_size = image_size
        self.objective = objective

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # helper function to register buffer from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register_buffer(
            "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised: bool):
        model_output = model(x, t)

        if self.objective == "pred_noise":
            x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
        elif self.objective == "pred_x0":
            x_start = model_output
        else:
            raise ValueError(f"unknown objective {self.objective}")

        if clip_denoised:
            x_start.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
        if noise is None:
            noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return result

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(
            reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
        ):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, model, batch_size=16):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, model, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @property
    def loss_fn(self):
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")

    def p_losses(self, model, x_start, t, noise=None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = model(x, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_start
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target)
        return loss

    def forward(self, model, img, *args, **kwargs):
        b, _, h, w, device, img_size, = (
            *img.shape,
            img.device,
            self.image_size,
        )
        assert h == img_size and w == img_size, f"height and width of image must be {img_size}"

        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(model, img, t, *args, **kwargs)
```
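The posterior buffers registered by the deleted class implement the closed-form Gaussian posterior $q(x_{t-1} \mid x_t, x_0)$ from the DDPM paper; `posterior_mean_coef1`, `posterior_mean_coef2` and `posterior_variance` are exactly the coefficients of

```latex
\tilde{\mu}_t(x_t, x_0)
  = \frac{\beta_t \sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_t} \, x_0
  + \frac{(1 - \bar{\alpha}_{t-1}) \sqrt{\alpha_t}}{1 - \bar{\alpha}_t} \, x_t,
\qquad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \, \beta_t
```

with $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ (the `alphas_cumprod` buffer).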
src/diffusers/samplers/__init__.py → src/diffusers/schedulers/__init__.py

@@ -16,4 +16,4 @@

```diff
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .gaussian import GaussianDiffusion
+from .gaussian_ddpm import GaussianDDPMScheduler
```
src/diffusers/schedulers/gaussian_ddpm.py — new file (0 → 100644)

```python
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from ..configuration_utils import Config


SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


class GaussianDDPMScheduler(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1)

        noise = self.sample_noise(shape, device=device, generator=generator)

        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise

        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
```
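A short usage sketch of the new scheduler's accessors, as the sampling loops elsewhere in this commit query them:

```python
import torch

from diffusers import GaussianDDPMScheduler

scheduler = GaussianDDPMScheduler(timesteps=10)  # default linear beta schedule

for t in reversed(range(len(scheduler))):       # __len__ is num_timesteps
    beta_t = scheduler.get_beta(t)              # beta_t of the schedule
    alpha_prod_t = scheduler.get_alpha_prod(t)  # cumulative product of alphas
    # get_alpha_prod(-1) returns 1.0, so t - 1 is safe at the final step
    alpha_prod_prev = scheduler.get_alpha_prod(t - 1)

noise = scheduler.sample_noise((1, 3, 8, 8), device="cpu")  # sampled on CPU, then moved to `device`
```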
tests/test_modeling_utils.py

@@ -16,13 +16,45 @@

```diff
 import random
 import tempfile
 import unittest
+import os
+from distutils.util import strtobool

 import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel


 global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def parse_flag_from_env(key, default=False):
+    try:
+        value = os.environ[key]
+    except KeyError:
+        # KEY isn't set, default to `default`.
+        _value = default
+    else:
+        # KEY is set, convert it to True or False.
+        try:
+            _value = strtobool(value)
+        except ValueError:
+            # More values are supported, but let's keep the message simple.
+            raise ValueError(f"If set, {key} must be yes or no.")
+    return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+
+
+def slow(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+    """
+    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


 def floats_tensor(shape, scale=1.0, rng=None, name=None):
 ...
```

@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):

```diff
         return (noise, time_step)

     def test_from_pretrained_save_pretrained(self):
-        model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
+        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)

         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
 ...
```

@@ -77,30 +109,93 @@

```diff
 class SamplerTesterMixin(unittest.TestCase):
     @property
     def dummy_model(self):
         return UNetModel.from_pretrained("fusing/ddpm_dummy")

-    def test_from_pretrained_save_pretrained(self):
-        sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            sampler.save_config(tmpdirname)
-            new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)
-
-        model = self.dummy_model
-
-        torch.manual_seed(0)
-        sampled_out = sampler.sample(model, batch_size=1)
+    @slow
+    def test_sample(self):
+        generator = torch.Generator()
+        generator = generator.manual_seed(6694729458485568)
+
+        # 1. Load models
+        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
+        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
+
+        # 2. Sample gaussian noise
+        image = scheduler.sample_noise(
+            (1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
+        )
+
+        # 3. Denoise
+        for t in reversed(range(len(scheduler))):
+            # i) define coefficients for time step t
+            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
+            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
+            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+
+            # ii) predict noise residual
+            with torch.no_grad():
+                noise_residual = model(image, t)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clip_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
+        # Note: The better test is to simply check with the following lines of code that the image is sensible
+        # import PIL
+        # import numpy as np
+        # image_processed = image.cpu().permute(0, 2, 3, 1)
+        # image_processed = (image_processed + 1.0) * 127.5
+        # image_processed = image_processed.numpy().astype(np.uint8)
+        # image_pil = PIL.Image.fromarray(image_processed[0])
+        # image_pil.save("test.png")
+
+        assert image.shape == (1, 3, 256, 256)
+        image_slice = image[0, -1, -3:, -3:].cpu()
+        assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3
+
+    def test_sample_fast(self):
+        # 1. Load models
+        generator = torch.Generator()
+        generator = generator.manual_seed(6694729458485568)
+
+        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
+        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

         # 2. Sample gaussian noise
-        torch.manual_seed(0)
-        sampled_out_new = new_sampler.sample(model, batch_size=1)
-
-        assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"
-
-    def test_from_pretrained_hub(self):
-        sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
-
-        model = self.dummy_model
-        sampled_out = sampler.sample(model, batch_size=1)
-
-        assert sampled_out is not None, "Make sure output is not None"
+        image = scheduler.sample_noise(
+            (1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
+        )
+
+        # 3. Denoise
+        for t in reversed(range(len(scheduler))):
+            # i) define coefficients for time step t
+            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
+            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
+            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+
+            # ii) predict noise residual
+            with torch.no_grad():
+                noise_residual = model(image, t)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clip_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
+        assert image.shape == (1, 3, 256, 256)
+        image_slice = image[0, -1, -3:, -3:].cpu()
+        assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
```