Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fe313730
Commit
fe313730
authored
Jun 06, 2022
by
Patrick von Platen
Browse files
improve
parent
3a5c65d5
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
691 additions
and
760 deletions
+691
-760
README.md
README.md
+4
-4
examples/sample_loop.py
examples/sample_loop.py
+142
-84
models/vision/ddpm/example.py
models/vision/ddpm/example.py
+5
-3
models/vision/ddpm/modeling_ddpm.py
models/vision/ddpm/modeling_ddpm.py
+0
-1
models/vision/ddpm/run_ddpm.py
models/vision/ddpm/run_ddpm.py
+2
-2
src/diffusers/__init__.py
src/diffusers/__init__.py
+2
-3
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+11
-7
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+1
-1
src/diffusers/models/unet.py
src/diffusers/models/unet.py
+298
-309
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+4
-4
src/diffusers/samplers/gaussian.py
src/diffusers/samplers/gaussian.py
+0
-313
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-1
src/diffusers/schedulers/gaussian_ddpm.py
src/diffusers/schedulers/gaussian_ddpm.py
+98
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+123
-28
No files found.
README.md
View file @
fe313730
...
@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th
...
@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th
Example:
Example:
```
python
```
python
from
diffusers
import
UNetModel
,
GaussianD
iffusion
from
diffusers
import
UNetModel
,
GaussianD
DPMScheduler
import
torch
import
torch
# 1. Load model
# 1. Load model
...
@@ -40,7 +40,7 @@ time_step = torch.tensor([10])
...
@@ -40,7 +40,7 @@ time_step = torch.tensor([10])
image
=
unet
(
dummy_noise
,
time_step
)
image
=
unet
(
dummy_noise
,
time_step
)
# 3. Load sampler
# 3. Load sampler
sampler
=
GaussianD
iffusion
.
from_config
(
"fusing/ddpm_dummy"
)
sampler
=
GaussianD
DPMScheduler
.
from_config
(
"fusing/ddpm_dummy"
)
# 4. Sample image from sampler passing the model
# 4. Sample image from sampler passing the model
image
=
sampler
.
sample
(
model
,
batch_size
=
1
)
image
=
sampler
.
sample
(
model
,
batch_size
=
1
)
...
@@ -54,12 +54,12 @@ print(image)
...
@@ -54,12 +54,12 @@ print(image)
Example:
Example:
```
python
```
python
from
diffusers
import
UNetModel
,
GaussianD
iffusion
from
diffusers
import
UNetModel
,
GaussianD
DPMScheduler
from
modeling_ddpm
import
DDPM
from
modeling_ddpm
import
DDPM
import
tempfile
import
tempfile
unet
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
unet
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
sampler
=
GaussianD
iffusion
.
from_config
(
"fusing/ddpm_dummy"
)
sampler
=
GaussianD
DPMScheduler
.
from_config
(
"fusing/ddpm_dummy"
)
# compose Diffusion Pipeline
# compose Diffusion Pipeline
ddpm
=
DDPM
(
unet
,
sampler
)
ddpm
=
DDPM
(
unet
,
sampler
)
...
...
examples/sample_loop.py
View file @
fe313730
#!/usr/bin/env python3
#!/usr/bin/env python3
from
diffusers
import
UNetModel
,
GaussianD
iffusion
from
diffusers
import
UNetModel
,
GaussianD
DPMScheduler
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
numpy
as
np
unet
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
import
PIL.Image
diffusion
=
GaussianDiffusion
.
from_config
(
"fusing/ddpm_dummy"
)
import
tqdm
#torch_device = "cuda"
#
#unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
#unet.to(torch_device)
#
#TIME_STEPS = 10
#
#scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
#
#diffusion_config = {
# "beta_start": 0.0001,
# "beta_end": 0.02,
# "num_diffusion_timesteps": TIME_STEPS,
#}
#
# 2. Do one denoising step with model
# 2. Do one denoising step with model
batch_size
,
num_channels
,
height
,
width
=
1
,
3
,
32
,
32
#batch_size, num_channels, height, width = 1, 3, 256, 256
dummy_noise
=
torch
.
ones
((
batch_size
,
num_channels
,
height
,
width
))
#
#torch.manual_seed(0)
#noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")
TIME_STEPS
=
10
#
#
# Helper
# Helper
def
extract
(
a
,
t
,
x_shape
):
#def noise_like(shape, device, repeat=False):
b
,
*
_
=
t
.
shape
# def repeat_noise():
out
=
a
.
gather
(
-
1
,
t
)
# return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
return
out
.
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_shape
)
-
1
)))
#
# def noise():
# return torch.randn(shape, device=device)
#
# return repeat_noise() if repeat else noise()
#
#
#betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
#betas = torch.tensor(betas, device=torch_device)
#alphas = 1.0 - betas
#
#alphas_cumprod = torch.cumprod(alphas, axis=0)
#alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
#
#posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
#
#posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
#posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
#
#
#sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
#sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
#
#
#noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
#coeff = 1 / torch.sqrt(alphas)
def
real_fn
():
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t
=
noise_image
# 2: for t = T, ...., 1 do
for
i
in
reversed
(
range
(
TIME_STEPS
)):
t
=
torch
.
tensor
([
i
]).
to
(
torch_device
)
# 3: z ~ N(0, 1)
noise
=
noise_like
(
x_t
.
shape
,
torch_device
)
# 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz
# ------------------------- MODEL ------------------------------------#
with
torch
.
no_grad
():
pred_noise
=
unet
(
x_t
,
t
)
# pred epsilon_theta
# pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
# pred_x.clamp_(-1.0, 1.0)
# pred mean
# posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
# --------------------------------------------------------------------#
posterior_mean
=
coeff
[
t
]
*
(
x_t
-
noise_coeff
[
t
]
*
pred_noise
)
# ------------------------- Variance Scheduler -----------------------#
# pred variance
posterior_log_variance
=
posterior_log_variance_clipped
[
t
]
b
,
*
_
,
device
=
*
x_t
.
shape
,
x_t
.
device
nonzero_mask
=
(
1
-
(
t
==
0
).
float
()).
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_t
.
shape
)
-
1
)))
posterior_variance
=
nonzero_mask
*
(
0.5
*
posterior_log_variance
).
exp
()
# --------------------------------------------------------------------#
x_t_1
=
(
posterior_mean
+
posterior_variance
*
noise
).
to
(
torch
.
float32
)
x_t
=
x_t_1
print
(
x_t
.
abs
().
sum
())
def
post_process_to_image
(
x_t
):
image
=
x_t
.
cpu
().
permute
(
0
,
2
,
3
,
1
)
image
=
(
image
+
1.0
)
*
127.5
image
=
image
.
numpy
().
astype
(
np
.
uint8
)
return
PIL
.
Image
.
fromarray
(
image
[
0
])
from
pytorch_diffusion
import
Diffusion
#diffusion = Diffusion.from_pretrained("lsun_church")
#samples = diffusion.denoise(1)
#
#image = post_process_to_image(samples)
#image.save("check.png")
#import ipdb; ipdb.set_trace()
device
=
"cuda"
scheduler
=
GaussianDDPMScheduler
.
from_config
(
"/home/patrick/ddpm-lsun-church"
,
timesteps
=
10
)
import
ipdb
;
ipdb
.
set_trace
()
model
=
UNetModel
.
from_pretrained
(
"/home/patrick/ddpm-lsun-church"
).
to
(
device
)
def
noise_like
(
shape
,
device
,
repeat
=
False
):
torch
.
manual_seed
(
0
)
def
repeat_noise
():
next_image
=
scheduler
.
sample_noise
((
1
,
model
.
in_channels
,
model
.
resolution
,
model
.
resolution
),
device
=
device
)
return
torch
.
randn
((
1
,
*
shape
[
1
:]),
device
=
device
).
repeat
(
shape
[
0
],
*
((
1
,)
*
(
len
(
shape
)
-
1
)))
def
noise
():
return
torch
.
randn
(
shape
,
device
=
device
)
return
repeat_noise
()
if
repeat
else
noise
()
# Schedule
def
cosine_beta_schedule
(
timesteps
,
s
=
0.008
):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps
=
timesteps
+
1
x
=
torch
.
linspace
(
0
,
timesteps
,
steps
,
dtype
=
torch
.
float64
)
alphas_cumprod
=
torch
.
cos
(((
x
/
timesteps
)
+
s
)
/
(
1
+
s
)
*
torch
.
pi
*
0.5
)
**
2
alphas_cumprod
=
alphas_cumprod
/
alphas_cumprod
[
0
]
betas
=
1
-
(
alphas_cumprod
[
1
:]
/
alphas_cumprod
[:
-
1
])
return
torch
.
clip
(
betas
,
0
,
0.999
)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
len
(
scheduler
))),
total
=
len
(
scheduler
)):
# define coefficients for time step t
clip_image_coeff
=
1
/
torch
.
sqrt
(
scheduler
.
get_alpha_prod
(
t
))
clip_noise_coeff
=
torch
.
sqrt
(
1
/
scheduler
.
get_alpha_prod
(
t
)
-
1
)
image_coeff
=
(
1
-
scheduler
.
get_alpha_prod
(
t
-
1
))
*
torch
.
sqrt
(
scheduler
.
get_alpha
(
t
))
/
(
1
-
scheduler
.
get_alpha_prod
(
t
))
clip_coeff
=
torch
.
sqrt
(
scheduler
.
get_alpha_prod
(
t
-
1
))
*
scheduler
.
get_beta
(
t
)
/
(
1
-
scheduler
.
get_alpha_prod
(
t
))
betas
=
cosine_beta_schedule
(
TIME_STEPS
)
# predict noise residual
alphas
=
1.0
-
betas
with
torch
.
no_grad
():
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
axis
=
0
)
noise_residual
=
model
(
next_image
,
t
)
alphas_cumprod_prev
=
F
.
pad
(
alphas_cumprod
[:
-
1
],
(
1
,
0
),
value
=
1.0
)
posterior_mean_coef1
=
betas
*
torch
.
sqrt
(
alphas_cumprod_prev
)
/
(
1.0
-
alphas_cumprod
)
# compute prev image from noise
posterior_mean_coef2
=
(
1.0
-
alphas_cumprod_prev
)
*
torch
.
sqrt
(
alphas
)
/
(
1.0
-
alphas_cumprod
)
pred_mean
=
clip_image_coeff
*
next_image
-
clip_noise_coeff
*
noise_residual
pred_mean
=
torch
.
clamp
(
pred_mean
,
-
1
,
1
)
image
=
clip_coeff
*
pred_mean
+
image_coeff
*
next_image
posterior_variance
=
betas
*
(
1.0
-
alphas_cumprod_prev
)
/
(
1.0
-
alphas_cumprod
)
# sample variance
posterior_log_variance_clipped
=
torch
.
log
(
posterior_variance
.
clamp
(
min
=
1e-20
)
)
variance
=
scheduler
.
sample_variance
(
t
,
image
.
shape
,
device
=
device
)
# sample previous image
sampled_image
=
image
+
variance
sqrt_recip_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
)
next_image
=
sampled_image
sqrt_recipm1_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
-
1
)
torch
.
manual_seed
(
0
)
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
image
=
post_process_to_image
(
next_image
)
# 1: x_t ~ N(0,1)
image
.
save
(
"example_new.png"
)
x_t
=
dummy_noise
# 2: for t = T, ...., 1 do
for
i
in
reversed
(
range
(
TIME_STEPS
)):
t
=
torch
.
tensor
([
i
])
# 3: z ~ N(0, 1)
noise
=
noise_like
(
x_t
.
shape
,
"cpu"
)
# 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz
# ------------------------- MODEL ------------------------------------#
pred_noise
=
unet
(
x_t
,
t
)
# pred epsilon_theta
pred_x
=
extract
(
sqrt_recip_alphas_cumprod
,
t
,
x_t
.
shape
)
*
x_t
-
extract
(
sqrt_recipm1_alphas_cumprod
,
t
,
x_t
.
shape
)
*
pred_noise
pred_x
.
clamp_
(
-
1.0
,
1.0
)
# pred mean
posterior_mean
=
extract
(
posterior_mean_coef1
,
t
,
x_t
.
shape
)
*
pred_x
+
extract
(
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
# --------------------------------------------------------------------#
# ------------------------- Variance Scheduler -----------------------#
# pred variance
posterior_log_variance
=
extract
(
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
b
,
*
_
,
device
=
*
x_t
.
shape
,
x_t
.
device
nonzero_mask
=
(
1
-
(
t
==
0
).
float
()).
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_t
.
shape
)
-
1
)))
posterior_variance
=
nonzero_mask
*
(
0.5
*
posterior_log_variance
).
exp
()
# --------------------------------------------------------------------#
x_t_1
=
(
posterior_mean
+
posterior_variance
*
noise
).
to
(
torch
.
float32
)
# FOR PATRICK TO VERIFY: make sure manual loop is equal to function
# --------------------------------------------------------------------#
x_t_12
=
diffusion
.
p_sample
(
unet
,
x_t
,
t
,
noise
=
noise
)
assert
(
x_t_1
-
x_t_12
).
abs
().
sum
().
item
()
<
1e-3
# --------------------------------------------------------------------#
x_t
=
x_t_1
models/vision/ddpm/example.py
View file @
fe313730
#!/usr/bin/env python3
#!/usr/bin/env python3
from
diffusers
import
UNetModel
,
GaussianDiffusion
from
modeling_ddpm
import
DDPM
import
tempfile
import
tempfile
from
diffusers
import
GaussianDDPMScheduler
,
UNetModel
from
modeling_ddpm
import
DDPM
unet
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
unet
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
sampler
=
GaussianD
iffusion
.
from_config
(
"fusing/ddpm_dummy"
)
sampler
=
GaussianD
DPMScheduler
.
from_config
(
"fusing/ddpm_dummy"
)
# compose Diffusion Pipeline
# compose Diffusion Pipeline
ddpm
=
DDPM
(
unet
,
sampler
)
ddpm
=
DDPM
(
unet
,
sampler
)
...
...
models/vision/ddpm/modeling_ddpm.py
View file @
fe313730
...
@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline
...
@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline
class
DDPM
(
DiffusionPipeline
):
class
DDPM
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
gaussian_sampler
):
def
__init__
(
self
,
unet
,
gaussian_sampler
):
super
().
__init__
(
unet
=
unet
,
gaussian_sampler
=
gaussian_sampler
)
super
().
__init__
(
unet
=
unet
,
gaussian_sampler
=
gaussian_sampler
)
...
...
models/vision/ddpm/run_ddpm.py
View file @
fe313730
#!/usr/bin/env python3
#!/usr/bin/env python3
import
torch
import
torch
from
diffusers
import
GaussianD
iffusion
,
UNetModel
from
diffusers
import
GaussianD
DPMScheduler
,
UNetModel
model
=
UNetModel
(
dim
=
64
,
dim_mults
=
(
1
,
2
,
4
,
8
))
model
=
UNetModel
(
dim
=
64
,
dim_mults
=
(
1
,
2
,
4
,
8
))
diffusion
=
GaussianD
iffusion
(
model
,
image_size
=
128
,
timesteps
=
1000
,
loss_type
=
"l1"
)
# number of steps # L1 or L2
diffusion
=
GaussianD
DPMScheduler
(
model
,
image_size
=
128
,
timesteps
=
1000
,
loss_type
=
"l1"
)
# number of steps # L1 or L2
training_images
=
torch
.
randn
(
8
,
3
,
128
,
128
)
# your images need to be normalized from a range of -1 to +1
training_images
=
torch
.
randn
(
8
,
3
,
128
,
128
)
# your images need to be normalized from a range of -1 to +1
loss
=
diffusion
(
training_images
)
loss
=
diffusion
(
training_images
)
...
...
src/diffusers/__init__.py
View file @
fe313730
...
@@ -4,8 +4,7 @@
...
@@ -4,8 +4,7 @@
__version__
=
"0.0.1"
__version__
=
"0.0.1"
from
.modeling_utils
import
PreTrainedModel
from
.models.unet
import
UNetModel
from
.models.unet
import
UNetModel
from
.samplers.gaussian
import
GaussianDiffusion
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.
modeling_utils
import
PreTrainedModel
from
.
schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
src/diffusers/configuration_utils.py
View file @
fe313730
...
@@ -17,10 +17,10 @@
...
@@ -17,10 +17,10 @@
import
copy
import
copy
import
inspect
import
json
import
json
import
os
import
os
import
re
import
re
import
inspect
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Tuple
,
Union
from
requests
import
HTTPError
from
requests
import
HTTPError
...
@@ -186,6 +186,11 @@ class Config:
...
@@ -186,6 +186,11 @@ class Config:
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
=
set
(
dict
(
inspect
.
signature
(
cls
.
__init__
).
parameters
).
keys
())
expected_keys
.
remove
(
"self"
)
expected_keys
.
remove
(
"self"
)
for
key
in
expected_keys
:
if
key
in
kwargs
:
# overwrite key
config_dict
[
key
]
=
kwargs
.
pop
(
key
)
passed_keys
=
set
(
config_dict
.
keys
())
passed_keys
=
set
(
config_dict
.
keys
())
unused_kwargs
=
kwargs
unused_kwargs
=
kwargs
...
@@ -194,17 +199,16 @@ class Config:
...
@@ -194,17 +199,16 @@ class Config:
if
len
(
expected_keys
-
passed_keys
)
>
0
:
if
len
(
expected_keys
-
passed_keys
)
>
0
:
logger
.
warn
(
logger
.
warn
(
f
"
{
expected_keys
-
passed_keys
}
was not found in config. "
f
"
{
expected_keys
-
passed_keys
}
was not found in config. Values will be initialized to default values."
f
"Values will be initialized to default values."
)
)
return
config_dict
,
unused_kwargs
return
config_dict
,
unused_kwargs
@
classmethod
@
classmethod
def
from_config
(
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
c
ls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
c
onfig_dict
,
unused_kwargs
=
cls
.
get_config_dict
(
):
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
config_dict
,
unused_kwargs
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
)
model
=
cls
(
**
config_dict
)
model
=
cls
(
**
config_dict
)
...
...
src/diffusers/modeling_utils.py
View file @
fe313730
...
@@ -24,6 +24,7 @@ from requests import HTTPError
...
@@ -24,6 +24,7 @@ from requests import HTTPError
# CHANGE to diffusers.utils
# CHANGE to diffusers.utils
from
transformers.utils
import
(
from
transformers.utils
import
(
CONFIG_NAME
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
HUGGINGFACE_CO_RESOLVE_ENDPOINT
,
EntryNotFoundError
,
EntryNotFoundError
,
RepositoryNotFoundError
,
RepositoryNotFoundError
,
...
@@ -33,7 +34,6 @@ from transformers.utils import (
...
@@ -33,7 +34,6 @@ from transformers.utils import (
is_offline_mode
,
is_offline_mode
,
is_remote_url
,
is_remote_url
,
logging
,
logging
,
CONFIG_NAME
,
)
)
...
...
src/diffusers/models/unet.py
View file @
fe313730
This diff is collapsed.
Click to expand it.
src/diffusers/pipeline_utils.py
View file @
fe313730
...
@@ -14,15 +14,15 @@
...
@@ -14,15 +14,15 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
importlib
import
os
import
os
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
import
importlib
from
.configuration_utils
import
Config
# CHANGE to diffusers.utils
# CHANGE to diffusers.utils
from
transformers.utils
import
logging
from
transformers.utils
import
logging
from
.configuration_utils
import
Config
INDEX_FILE
=
"diffusion_model.pt"
INDEX_FILE
=
"diffusion_model.pt"
...
@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)
...
@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES
=
{
LOADABLE_CLASSES
=
{
"diffusers"
:
{
"diffusers"
:
{
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"GaussianD
iffusion
"
:
[
"save_config"
,
"from_config"
],
"GaussianD
DPMScheduler
"
:
[
"save_config"
,
"from_config"
],
},
},
"transformers"
:
{
"transformers"
:
{
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
"PreTrainedModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
...
...
src/diffusers/samplers/gaussian.py
deleted
100644 → 0
View file @
3a5c65d5
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
inspect
import
isfunction
from
tqdm
import
tqdm
from
..configuration_utils
import
Config
SAMPLING_CONFIG_NAME
=
"sampler_config.json"
def
exists
(
x
):
return
x
is
not
None
def
default
(
val
,
d
):
if
exists
(
val
):
return
val
return
d
()
if
isfunction
(
d
)
else
d
def
cycle
(
dl
):
while
True
:
for
data_dl
in
dl
:
yield
data_dl
def
num_to_groups
(
num
,
divisor
):
groups
=
num
//
divisor
remainder
=
num
%
divisor
arr
=
[
divisor
]
*
groups
if
remainder
>
0
:
arr
.
append
(
remainder
)
return
arr
def
normalize_to_neg_one_to_one
(
img
):
return
img
*
2
-
1
def
unnormalize_to_zero_to_one
(
t
):
return
(
t
+
1
)
*
0.5
# small helper modules
class
EMA
:
def
__init__
(
self
,
beta
):
super
().
__init__
()
self
.
beta
=
beta
def
update_model_average
(
self
,
ma_model
,
current_model
):
for
current_params
,
ma_params
in
zip
(
current_model
.
parameters
(),
ma_model
.
parameters
()):
old_weight
,
up_weight
=
ma_params
.
data
,
current_params
.
data
ma_params
.
data
=
self
.
update_average
(
old_weight
,
up_weight
)
def
update_average
(
self
,
old
,
new
):
if
old
is
None
:
return
new
return
old
*
self
.
beta
+
(
1
-
self
.
beta
)
*
new
# gaussian diffusion trainer class
def
extract
(
a
,
t
,
x_shape
):
b
,
*
_
=
t
.
shape
out
=
a
.
gather
(
-
1
,
t
)
return
out
.
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_shape
)
-
1
)))
def
noise_like
(
shape
,
device
,
repeat
=
False
):
def
repeat_noise
():
return
torch
.
randn
((
1
,
*
shape
[
1
:]),
device
=
device
).
repeat
(
shape
[
0
],
*
((
1
,)
*
(
len
(
shape
)
-
1
)))
def
noise
():
return
torch
.
randn
(
shape
,
device
=
device
)
return
repeat_noise
()
if
repeat
else
noise
()
def
linear_beta_schedule
(
timesteps
):
scale
=
1000
/
timesteps
beta_start
=
scale
*
0.0001
beta_end
=
scale
*
0.02
return
torch
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
torch
.
float64
)
def
cosine_beta_schedule
(
timesteps
,
s
=
0.008
):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps
=
timesteps
+
1
x
=
torch
.
linspace
(
0
,
timesteps
,
steps
,
dtype
=
torch
.
float64
)
alphas_cumprod
=
torch
.
cos
(((
x
/
timesteps
)
+
s
)
/
(
1
+
s
)
*
torch
.
pi
*
0.5
)
**
2
alphas_cumprod
=
alphas_cumprod
/
alphas_cumprod
[
0
]
betas
=
1
-
(
alphas_cumprod
[
1
:]
/
alphas_cumprod
[:
-
1
])
return
torch
.
clip
(
betas
,
0
,
0.999
)
class
GaussianDiffusion
(
nn
.
Module
,
Config
):
config_name
=
SAMPLING_CONFIG_NAME
def
__init__
(
self
,
image_size
,
channels
=
3
,
timesteps
=
1000
,
loss_type
=
"l1"
,
objective
=
"pred_noise"
,
beta_schedule
=
"cosine"
,
):
super
().
__init__
()
self
.
register
(
image_size
=
image_size
,
channels
=
channels
,
timesteps
=
timesteps
,
loss_type
=
loss_type
,
objective
=
objective
,
beta_schedule
=
beta_schedule
,
)
self
.
channels
=
channels
self
.
image_size
=
image_size
self
.
objective
=
objective
if
beta_schedule
==
"linear"
:
betas
=
linear_beta_schedule
(
timesteps
)
elif
beta_schedule
==
"cosine"
:
betas
=
cosine_beta_schedule
(
timesteps
)
else
:
raise
ValueError
(
f
"unknown beta schedule
{
beta_schedule
}
"
)
alphas
=
1.0
-
betas
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
axis
=
0
)
alphas_cumprod_prev
=
F
.
pad
(
alphas_cumprod
[:
-
1
],
(
1
,
0
),
value
=
1.0
)
(
timesteps
,)
=
betas
.
shape
self
.
num_timesteps
=
int
(
timesteps
)
self
.
loss_type
=
loss_type
# helper function to register buffer from float64 to float32
def
register_buffer
(
name
,
val
):
self
.
register_buffer
(
name
,
val
.
to
(
torch
.
float32
))
register_buffer
(
"betas"
,
betas
)
register_buffer
(
"alphas_cumprod"
,
alphas_cumprod
)
register_buffer
(
"alphas_cumprod_prev"
,
alphas_cumprod_prev
)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer
(
"sqrt_alphas_cumprod"
,
torch
.
sqrt
(
alphas_cumprod
))
register_buffer
(
"sqrt_one_minus_alphas_cumprod"
,
torch
.
sqrt
(
1.0
-
alphas_cumprod
))
register_buffer
(
"log_one_minus_alphas_cumprod"
,
torch
.
log
(
1.0
-
alphas_cumprod
))
register_buffer
(
"sqrt_recip_alphas_cumprod"
,
torch
.
sqrt
(
1.0
/
alphas_cumprod
))
register_buffer
(
"sqrt_recipm1_alphas_cumprod"
,
torch
.
sqrt
(
1.0
/
alphas_cumprod
-
1
))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance
=
betas
*
(
1.0
-
alphas_cumprod_prev
)
/
(
1.0
-
alphas_cumprod
)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer
(
"posterior_variance"
,
posterior_variance
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer
(
"posterior_log_variance_clipped"
,
torch
.
log
(
posterior_variance
.
clamp
(
min
=
1e-20
)))
register_buffer
(
"posterior_mean_coef1"
,
betas
*
torch
.
sqrt
(
alphas_cumprod_prev
)
/
(
1.0
-
alphas_cumprod
))
register_buffer
(
"posterior_mean_coef2"
,
(
1.0
-
alphas_cumprod_prev
)
*
torch
.
sqrt
(
alphas
)
/
(
1.0
-
alphas_cumprod
)
)
def
predict_start_from_noise
(
self
,
x_t
,
t
,
noise
):
return
(
extract
(
self
.
sqrt_recip_alphas_cumprod
,
t
,
x_t
.
shape
)
*
x_t
-
extract
(
self
.
sqrt_recipm1_alphas_cumprod
,
t
,
x_t
.
shape
)
*
noise
)
def
q_posterior
(
self
,
x_start
,
x_t
,
t
):
posterior_mean
=
(
extract
(
self
.
posterior_mean_coef1
,
t
,
x_t
.
shape
)
*
x_start
+
extract
(
self
.
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
)
posterior_variance
=
extract
(
self
.
posterior_variance
,
t
,
x_t
.
shape
)
posterior_log_variance_clipped
=
extract
(
self
.
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
return
posterior_mean
,
posterior_variance
,
posterior_log_variance_clipped
def
p_mean_variance
(
self
,
model
,
x
,
t
,
clip_denoised
:
bool
):
model_output
=
model
(
x
,
t
)
if
self
.
objective
==
"pred_noise"
:
x_start
=
self
.
predict_start_from_noise
(
x
,
t
=
t
,
noise
=
model_output
)
elif
self
.
objective
==
"pred_x0"
:
x_start
=
model_output
else
:
raise
ValueError
(
f
"unknown objective
{
self
.
objective
}
"
)
if
clip_denoised
:
x_start
.
clamp_
(
-
1.0
,
1.0
)
model_mean
,
posterior_variance
,
posterior_log_variance
=
self
.
q_posterior
(
x_start
=
x_start
,
x_t
=
x
,
t
=
t
)
return
model_mean
,
posterior_variance
,
posterior_log_variance
@
torch
.
no_grad
()
def
p_sample
(
self
,
model
,
x
,
t
,
noise
=
None
,
clip_denoised
=
True
,
repeat_noise
=
False
):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
model_mean
,
_
,
model_log_variance
=
self
.
p_mean_variance
(
model
=
model
,
x
=
x
,
t
=
t
,
clip_denoised
=
clip_denoised
)
if
noise
is
None
:
noise
=
noise_like
(
x
.
shape
,
device
,
repeat_noise
)
# no noise when t == 0
nonzero_mask
=
(
1
-
(
t
==
0
).
float
()).
reshape
(
b
,
*
((
1
,)
*
(
len
(
x
.
shape
)
-
1
)))
result
=
model_mean
+
nonzero_mask
*
(
0.5
*
model_log_variance
).
exp
()
*
noise
return
result
@
torch
.
no_grad
()
def
p_sample_loop
(
self
,
model
,
shape
):
device
=
self
.
betas
.
device
b
=
shape
[
0
]
img
=
torch
.
randn
(
shape
,
device
=
device
)
for
i
in
tqdm
(
reversed
(
range
(
0
,
self
.
num_timesteps
)),
desc
=
"sampling loop time step"
,
total
=
self
.
num_timesteps
):
img
=
self
.
p_sample
(
model
,
img
,
torch
.
full
((
b
,),
i
,
device
=
device
,
dtype
=
torch
.
long
))
img
=
unnormalize_to_zero_to_one
(
img
)
return
img
@
torch
.
no_grad
()
def
sample
(
self
,
model
,
batch_size
=
16
):
image_size
=
self
.
image_size
channels
=
self
.
channels
return
self
.
p_sample_loop
(
model
,
(
batch_size
,
channels
,
image_size
,
image_size
))
@
torch
.
no_grad
()
def
interpolate
(
self
,
model
,
x1
,
x2
,
t
=
None
,
lam
=
0.5
):
b
,
*
_
,
device
=
*
x1
.
shape
,
x1
.
device
t
=
default
(
t
,
self
.
num_timesteps
-
1
)
assert
x1
.
shape
==
x2
.
shape
t_batched
=
torch
.
stack
([
torch
.
tensor
(
t
,
device
=
device
)]
*
b
)
xt1
,
xt2
=
map
(
lambda
x
:
self
.
q_sample
(
x
,
t
=
t_batched
),
(
x1
,
x2
))
img
=
(
1
-
lam
)
*
xt1
+
lam
*
xt2
for
i
in
tqdm
(
reversed
(
range
(
0
,
t
)),
desc
=
"interpolation sample time step"
,
total
=
t
):
img
=
self
.
p_sample
(
model
,
img
,
torch
.
full
((
b
,),
i
,
device
=
device
,
dtype
=
torch
.
long
))
return
img
def
q_sample
(
self
,
x_start
,
t
,
noise
=
None
):
noise
=
default
(
noise
,
lambda
:
torch
.
randn_like
(
x_start
))
return
(
extract
(
self
.
sqrt_alphas_cumprod
,
t
,
x_start
.
shape
)
*
x_start
+
extract
(
self
.
sqrt_one_minus_alphas_cumprod
,
t
,
x_start
.
shape
)
*
noise
)
@
property
def
loss_fn
(
self
):
if
self
.
loss_type
==
"l1"
:
return
F
.
l1_loss
elif
self
.
loss_type
==
"l2"
:
return
F
.
mse_loss
else
:
raise
ValueError
(
f
"invalid loss type
{
self
.
loss_type
}
"
)
def
p_losses
(
self
,
model
,
x_start
,
t
,
noise
=
None
):
b
,
c
,
h
,
w
=
x_start
.
shape
noise
=
default
(
noise
,
lambda
:
torch
.
randn_like
(
x_start
))
x
=
self
.
q_sample
(
x_start
=
x_start
,
t
=
t
,
noise
=
noise
)
model_out
=
model
(
x
,
t
)
if
self
.
objective
==
"pred_noise"
:
target
=
noise
elif
self
.
objective
==
"pred_x0"
:
target
=
x_start
else
:
raise
ValueError
(
f
"unknown objective
{
self
.
objective
}
"
)
loss
=
self
.
loss_fn
(
model_out
,
target
)
return
loss
def
forward
(
self
,
model
,
img
,
*
args
,
**
kwargs
):
b
,
_
,
h
,
w
,
device
,
img_size
,
=
(
*
img
.
shape
,
img
.
device
,
self
.
image_size
,
)
assert
h
==
img_size
and
w
==
img_size
,
f
"height and width of image must be
{
img_size
}
"
t
=
torch
.
randint
(
0
,
self
.
num_timesteps
,
(
b
,),
device
=
device
).
long
()
img
=
normalize_to_neg_one_to_one
(
img
)
return
self
.
p_losses
(
model
,
img
,
t
,
*
args
,
**
kwargs
)
src/diffusers/s
amp
lers/__init__.py
→
src/diffusers/s
chedu
lers/__init__.py
View file @
fe313730
...
@@ -16,4 +16,4 @@
...
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.gaussian
import
GaussianD
iffusion
from
.gaussian
_ddpm
import
GaussianD
DPMScheduler
src/diffusers/schedulers/gaussian_ddpm.py
0 → 100644
View file @
fe313730
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
from
torch
import
nn
from
..configuration_utils
import
Config
SAMPLING_CONFIG_NAME
=
"scheduler_config.json"
def
linear_beta_schedule
(
timesteps
,
beta_start
,
beta_end
):
return
torch
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
torch
.
float64
)
class
GaussianDDPMScheduler
(
nn
.
Module
,
Config
):
config_name
=
SAMPLING_CONFIG_NAME
def
__init__
(
self
,
timesteps
=
1000
,
beta_start
=
0.0001
,
beta_end
=
0.02
,
beta_schedule
=
"linear"
,
variance_type
=
"fixed_small"
,
):
super
().
__init__
()
self
.
register
(
timesteps
=
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
,
beta_schedule
=
beta_schedule
,
variance_type
=
variance_type
,
)
self
.
num_timesteps
=
int
(
timesteps
)
if
beta_schedule
==
"linear"
:
betas
=
linear_beta_schedule
(
timesteps
,
beta_start
=
beta_start
,
beta_end
=
beta_end
)
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
alphas
=
1.0
-
betas
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
axis
=
0
)
alphas_cumprod_prev
=
torch
.
nn
.
functional
.
pad
(
alphas_cumprod
[:
-
1
],
(
1
,
0
),
value
=
1.0
)
variance
=
betas
*
(
1.0
-
alphas_cumprod_prev
)
/
(
1.0
-
alphas_cumprod
)
if
variance_type
==
"fixed_small"
:
log_variance
=
torch
.
log
(
variance
.
clamp
(
min
=
1e-20
))
elif
variance_type
==
"fixed_large"
:
log_variance
=
torch
.
log
(
torch
.
cat
([
variance
[
1
:
2
],
betas
[
1
:]],
dim
=
0
))
self
.
register_buffer
(
"betas"
,
betas
.
to
(
torch
.
float32
))
self
.
register_buffer
(
"alphas"
,
alphas
.
to
(
torch
.
float32
))
self
.
register_buffer
(
"alphas_cumprod"
,
alphas_cumprod
.
to
(
torch
.
float32
))
self
.
register_buffer
(
"log_variance"
,
log_variance
.
to
(
torch
.
float32
))
def
get_alpha
(
self
,
time_step
):
return
self
.
alphas
[
time_step
]
def
get_beta
(
self
,
time_step
):
return
self
.
betas
[
time_step
]
def
get_alpha_prod
(
self
,
time_step
):
if
time_step
<
0
:
return
torch
.
tensor
(
1.0
)
return
self
.
alphas_cumprod
[
time_step
]
def
sample_variance
(
self
,
time_step
,
shape
,
device
,
generator
=
None
):
variance
=
self
.
log_variance
[
time_step
]
nonzero_mask
=
torch
.
tensor
([
1
-
(
time_step
==
0
)],
device
=
device
).
float
()[
None
,
:].
repeat
(
shape
[
0
],
1
)
noise
=
self
.
sample_noise
(
shape
,
device
=
device
,
generator
=
generator
)
sampled_variance
=
nonzero_mask
*
(
0.5
*
variance
).
exp
()
sampled_variance
=
sampled_variance
*
noise
return
sampled_variance
def
sample_noise
(
self
,
shape
,
device
,
generator
=
None
):
# always sample on CPU to be deterministic
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
def
__len__
(
self
):
return
self
.
num_timesteps
tests/test_modeling_utils.py
View file @
fe313730
...
@@ -16,13 +16,45 @@
...
@@ -16,13 +16,45 @@
import
random
import
random
import
tempfile
import
tempfile
import
unittest
import
unittest
import
os
from
distutils.util
import
strtobool
import
torch
import
torch
from
diffusers
import
GaussianD
iffusion
,
UNetModel
from
diffusers
import
GaussianD
DPMScheduler
,
UNetModel
global_rng
=
random
.
Random
()
global_rng
=
random
.
Random
()
# Run tests on GPU when one is available, otherwise fall back to CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable `key`.

    Returns `default` when the variable is unset. When set, the value is
    interpreted case-insensitively the way `distutils.util.strtobool` did:
    "y"/"yes"/"t"/"true"/"on"/"1" -> 1 and "n"/"no"/"f"/"false"/"off"/"0" -> 0.
    The parsing is inlined because `distutils` was removed in Python 3.12.

    Raises:
        ValueError: if the variable is set to an unrecognized value.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default
    # KEY is set, convert it to 1 (truthy) or 0 (falsy).
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return 1
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return 0
    # More values are supported, but let's keep the message simple.
    raise ValueError(f"If set, {key} must be yes or no.") from None
# Whether tests decorated with @slow should run; controlled by the RUN_SLOW env var.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
    """
    skip_unless_slow = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow(test_case)
def
floats_tensor
(
shape
,
scale
=
1.0
,
rng
=
None
,
name
=
None
):
def
floats_tensor
(
shape
,
scale
=
1.0
,
rng
=
None
,
name
=
None
):
...
@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):
...
@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):
return
(
noise
,
time_step
)
return
(
noise
,
time_step
)
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
model
=
UNetModel
(
dim
=
8
,
dim
_mult
s
=
(
1
,
2
),
resnet_block_groups
=
2
)
model
=
UNetModel
(
ch
=
32
,
ch
_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
3
2
)
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
model
.
save_pretrained
(
tmpdirname
)
...
@@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase):
...
@@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase):
class SamplerTesterMixin(unittest.TestCase):
    """Integration tests for manual DDPM sampling with GaussianDDPMScheduler + UNetModel.

    NOTE(review): reconstructed from a side-by-side diff view; both tests download
    the "fusing/ddpm-lsun-church" checkpoint from the hub — confirm against the
    repository file.
    """

    @slow
    def test_sample(self):
        # Fixed generator so the sampled noise (and hence the output slice) is reproducible.
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        # 1. Load models
        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        image = scheduler.sample_noise(
            (1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
        )

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            )
            clip_coeff = (
                torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        # Note: The better test is to simply check with the following lines of code that the image is sensible
        # import PIL
        # import numpy as np
        # image_processed = image.cpu().permute(0, 2, 3, 1)
        # image_processed = (image_processed + 1.0) * 127.5
        # image_processed = image_processed.numpy().astype(np.uint8)
        # image_pil = PIL.Image.fromarray(image_processed[0])
        # image_pil.save("test.png")

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (
            image_slice
            - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])
        ).abs().sum() < 1e-3

    def test_sample_fast(self):
        # Same denoising loop as test_sample, but with only 10 timesteps so it runs quickly.
        generator = torch.Generator()
        generator = generator.manual_seed(6694729458485568)

        # 1. Load models
        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)

        # 2. Sample gaussian noise
        torch.manual_seed(0)
        image = scheduler.sample_noise(
            (1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
        )

        # 3. Denoise
        for t in reversed(range(len(scheduler))):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
            )
            clip_coeff = (
                torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = model(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
            image = sampled_prev_image

        assert image.shape == (1, 3, 256, 256)
        image_slice = image[0, -1, -3:, -3:].cpu()
        assert (
            image_slice
            - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])
        ).abs().sum() < 1e-3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment