OpenDAS / diffusers · Commits

Commit fe313730, authored Jun 06, 2022 by Patrick von Platen

improve

Parent: 3a5c65d5

Showing 14 changed files, with 691 additions and 760 deletions.
README.md                                    +4    -4
examples/sample_loop.py                      +142  -84
models/vision/ddpm/example.py                +5    -3
models/vision/ddpm/modeling_ddpm.py          +0    -1
models/vision/ddpm/run_ddpm.py               +2    -2
src/diffusers/__init__.py                    +2    -3
src/diffusers/configuration_utils.py         +11   -7
src/diffusers/modeling_utils.py              +1    -1
src/diffusers/models/unet.py                 +298  -309
src/diffusers/pipeline_utils.py              +4    -4
src/diffusers/samplers/gaussian.py           +0    -313
src/diffusers/schedulers/__init__.py         +1    -1
src/diffusers/schedulers/gaussian_ddpm.py    +98   -0
tests/test_modeling_utils.py                 +123  -28
README.md

@@ -27,7 +27,7 @@ One should be able to save both models and samplers as well as load them from th

 Example:

 ```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
 import torch

 # 1. Load model

@@ -40,7 +40,7 @@ time_step = torch.tensor([10])

 image = unet(dummy_noise, time_step)

 # 3. Load sampler
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # 4. Sample image from sampler passing the model
 image = sampler.sample(model, batch_size=1)

@@ -54,12 +54,12 @@ print(image)

 Example:

 ```python
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
 from modeling_ddpm import DDPM
 import tempfile

 unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

 # compose Diffusion Pipeline
 ddpm = DDPM(unet, sampler)
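Assembled end to end, the post-commit README example reads as follows. This is a sketch only: the checkpoint name `fusing/ddpm_dummy` and the call sequence come straight from the snippets above, but note that `sampler.sample(model, batch_size=1)` is the README's wording, while the `GaussianDDPMScheduler` added in this commit only defines `sample_noise`, `sample_variance`, and the `get_alpha*`/`get_beta` helpers, so the final call may be stale relative to the new class.

```python
import torch

from diffusers import UNetModel, GaussianDDPMScheduler

# 1. Load model (the dummy checkpoint used throughout this repo's examples)
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")

# 2. Do one denoising step with the model
dummy_noise = torch.ones((1, 3, 32, 32))
time_step = torch.tensor([10])
image = unet(dummy_noise, time_step)

# 3. Load sampler
sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# 4. Sample image from sampler passing the model
# (kept as the README writes it; see the caveat in the lead-in above)
image = sampler.sample(unet, batch_size=1)
print(image)
```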
examples/sample_loop.py

#!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
+from diffusers import UNetModel, GaussianDDPMScheduler
import torch
import torch.nn.functional as F

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy")

import numpy as np
import PIL.Image
import tqdm

# torch_device = "cuda"
#
# unet = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church")
# unet.to(torch_device)
#
# TIME_STEPS = 10
#
# scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=TIME_STEPS)
#
# diffusion_config = {
#     "beta_start": 0.0001,
#     "beta_end": 0.02,
#     "num_diffusion_timesteps": TIME_STEPS,
# }
#

# 2. Do one denoising step with model
batch_size, num_channels, height, width = 1, 3, 32, 32
dummy_noise = torch.ones((batch_size, num_channels, height, width))
TIME_STEPS = 10

# batch_size, num_channels, height, width = 1, 3, 256, 256
#
# torch.manual_seed(0)
# noise_image = torch.randn(batch_size, num_channels, height, width, device="cuda")


# Helper
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


# def noise_like(shape, device, repeat=False):
#     def repeat_noise():
#         return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
#
#     def noise():
#         return torch.randn(shape, device=device)
#
#     return repeat_noise() if repeat else noise()
#
#
# betas = np.linspace(diffusion_config["beta_start"], diffusion_config["beta_end"], diffusion_config["num_diffusion_timesteps"], dtype=np.float64)
# betas = torch.tensor(betas, device=torch_device)
# alphas = 1.0 - betas
#
# alphas_cumprod = torch.cumprod(alphas, axis=0)
# alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
#
# posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
#
# posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
#
#
# sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
# sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
#
#
# noise_coeff = (1 - alphas) / torch.sqrt(1 - alphas_cumprod)
# coeff = 1 / torch.sqrt(alphas)


def real_fn():
    # Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
    # 1: x_t ~ N(0,1)
    x_t = noise_image

    # 2: for t = T, ...., 1 do
    for i in reversed(range(TIME_STEPS)):
        t = torch.tensor([i]).to(torch_device)

        # 3: z ~ N(0, 1)
        noise = noise_like(x_t.shape, torch_device)

        # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
        # ------------------------- MODEL ------------------------------------#
        with torch.no_grad():
            pred_noise = unet(x_t, t)

        # pred epsilon_theta
        # pred_x = sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * pred_noise
        # pred_x.clamp_(-1.0, 1.0)
        # pred mean
        # posterior_mean = posterior_mean_coef1[t] * pred_x + posterior_mean_coef2[t] * x_t
        # --------------------------------------------------------------------#
        posterior_mean = coeff[t] * (x_t - noise_coeff[t] * pred_noise)

        # ------------------------- Variance Scheduler -----------------------#
        # pred variance
        posterior_log_variance = posterior_log_variance_clipped[t]
        b, *_, device = *x_t.shape, x_t.device
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
        posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
        # --------------------------------------------------------------------#

        x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
        x_t = x_t_1
        print(x_t.abs().sum())


def post_process_to_image(x_t):
    image = x_t.cpu().permute(0, 2, 3, 1)
    image = (image + 1.0) * 127.5
    image = image.numpy().astype(np.uint8)
    return PIL.Image.fromarray(image[0])


from pytorch_diffusion import Diffusion

# diffusion = Diffusion.from_pretrained("lsun_church")
# samples = diffusion.denoise(1)
#
# image = post_process_to_image(samples)
# image.save("check.png")
# import ipdb; ipdb.set_trace()

device = "cuda"
scheduler = GaussianDDPMScheduler.from_config("/home/patrick/ddpm-lsun-church", timesteps=10)

import ipdb; ipdb.set_trace()

model = UNetModel.from_pretrained("/home/patrick/ddpm-lsun-church").to(device)


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


# Schedule
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


# manual schedule arrays, used by the second (extract-based) loop below
betas = cosine_beta_schedule(TIME_STEPS)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))

sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)

torch.manual_seed(0)
next_image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=device)

for t in tqdm.tqdm(reversed(range(len(scheduler))), total=len(scheduler)):
    # define coefficients for time step t
    clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
    clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
    image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
    clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))

    # predict noise residual
    with torch.no_grad():
        noise_residual = model(next_image, t)

    # compute prev image from noise
    pred_mean = clip_image_coeff * next_image - clip_noise_coeff * noise_residual
    pred_mean = torch.clamp(pred_mean, -1, 1)
    image = clip_coeff * pred_mean + image_coeff * next_image

    # sample variance
    variance = scheduler.sample_variance(t, image.shape, device=device)

    # sample previous image
    sampled_image = image + variance
    next_image = sampled_image

torch.manual_seed(0)

# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t = dummy_noise

# 2: for t = T, ...., 1 do
for i in reversed(range(TIME_STEPS)):
    t = torch.tensor([i])

    # 3: z ~ N(0, 1)
    noise = noise_like(x_t.shape, "cpu")

    # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sigma_t * z
    # ------------------------- MODEL ------------------------------------#
    pred_noise = unet(x_t, t)

    # pred epsilon_theta
    pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
    pred_x.clamp_(-1.0, 1.0)

    # pred mean
    posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
    # --------------------------------------------------------------------#

    # ------------------------- Variance Scheduler -----------------------#
    # pred variance
    posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
    b, *_, device = *x_t.shape, x_t.device
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
    posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
    # --------------------------------------------------------------------#

    x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)

    # FOR PATRICK TO VERIFY: make sure manual loop is equal to function
    # --------------------------------------------------------------------#
    x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
    assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
    # --------------------------------------------------------------------#

    x_t = x_t_1

image = post_process_to_image(next_image)
image.save("example_new.png")
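For reference, both manual loops above implement the update from Algorithm 2 of Ho et al. 2020 (https://arxiv.org/pdf/2006.11239.pdf), which the garbled step-4 comment originally transcribed:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$

Here `coeff` is $1/\sqrt{\alpha_t}$, `noise_coeff` is $(1-\alpha_t)/\sqrt{1-\bar\alpha_t}$, $\sigma_t = \exp(\tfrac{1}{2}\log\tilde\beta_t)$ comes from `posterior_log_variance_clipped`, and `nonzero_mask` suppresses the noise term at $t = 0$.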
models/vision/ddpm/example.py

#!/usr/bin/env python3
-from diffusers import UNetModel, GaussianDiffusion
-from modeling_ddpm import DDPM
-import tempfile
+from diffusers import GaussianDDPMScheduler, UNetModel
+from modeling_ddpm import DDPM

unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
-sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+sampler = GaussianDDPMScheduler.from_config("fusing/ddpm_dummy")

# compose Diffusion Pipeline
ddpm = DDPM(unet, sampler)
models/vision/ddpm/modeling_ddpm.py

@@ -18,7 +18,6 @@ from diffusers import DiffusionPipeline

class DDPM(DiffusionPipeline):
    def __init__(self, unet, gaussian_sampler):
        super().__init__(unet=unet, gaussian_sampler=gaussian_sampler)
models/vision/ddpm/run_ddpm.py

#!/usr/bin/env python3
import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel

model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))

-diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2
+diffusion = GaussianDDPMScheduler(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps # L1 or L2

training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized from a range of -1 to +1
loss = diffusion(training_images)
src/diffusers/__init__.py

@@ -4,8 +4,7 @@

__version__ = "0.0.1"

-from .modeling_utils import PreTrainedModel
from .models.unet import UNetModel
-from .samplers.gaussian import GaussianDiffusion
from .pipeline_utils import DiffusionPipeline
+from .modeling_utils import PreTrainedModel
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/configuration_utils.py

@@ -17,10 +17,10 @@

import copy
+import inspect
import json
import os
import re
-import inspect
from typing import Any, Dict, Tuple, Union

from requests import HTTPError

@@ -186,6 +186,11 @@ class Config:

        expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
        expected_keys.remove("self")
+        for key in expected_keys:
+            if key in kwargs:
+                # overwrite key
+                config_dict[key] = kwargs.pop(key)
+
        passed_keys = set(config_dict.keys())

        unused_kwargs = kwargs

@@ -194,17 +199,16 @@ class Config:

        if len(expected_keys - passed_keys) > 0:
            logger.warn(
-                f"{expected_keys - passed_keys} was not found in config. "
-                f"Values will be initialized to default values."
+                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        return config_dict, unused_kwargs

    @classmethod
    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
        config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

        model = cls(**config_dict)
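The keyword-override loop added in the second hunk is what makes calls like `GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)` (used in examples/sample_loop.py and tests/test_modeling_utils.py) take effect: any `__init__` parameter passed to `from_config` now replaces the stored config value instead of landing in `unused_kwargs`. A minimal sketch of the behavior, assuming the Hub config stores `timesteps=1000`:

```python
from diffusers import GaussianDDPMScheduler

# Stored config: timesteps=1000. The keyword overrides it, so the scheduler
# is built with 10 steps; __len__ returns num_timesteps.
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
assert len(scheduler) == 10
```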
src/diffusers/modeling_utils.py

@@ -24,6 +24,7 @@ from requests import HTTPError

# CHANGE to diffusers.utils
from transformers.utils import (
+    CONFIG_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    EntryNotFoundError,
    RepositoryNotFoundError,

@@ -33,7 +34,6 @@ from transformers.utils import (

    is_offline_mode,
    is_remote_url,
    logging,
-    CONFIG_NAME,
)
src/diffusers/models/unet.py

(Diff collapsed in the page view, +298 -309; not shown.)
src/diffusers/pipeline_utils.py

@@ -14,15 +14,15 @@

# See the License for the specific language governing permissions and
# limitations under the License.

+import importlib
import os
from typing import Optional, Union

-import importlib
-
-from .configuration_utils import Config

# CHANGE to diffusers.utils
from transformers.utils import logging

+from .configuration_utils import Config


INDEX_FILE = "diffusion_model.pt"

@@ -33,7 +33,7 @@ logger = logging.get_logger(__name__)

LOADABLE_CLASSES = {
    "diffusers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
-        "GaussianDiffusion": ["save_config", "from_config"],
+        "GaussianDDPMScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
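`LOADABLE_CLASSES` is why the rename shows up here at all: when a pipeline such as the `DDPM` class above saves or loads its registered modules, it looks each module's class up in this table to find the matching (save, load) method pair, so the scheduler's entry must track the new class name. A hypothetical illustration of that dispatch (not the verbatim library code):

```python
def get_save_method(module, loadable_classes):
    """Return the bound save method registered for `module`, e.g.
    save_config for a GaussianDDPMScheduler or save_pretrained for
    a PreTrainedModel."""
    for library_classes in loadable_classes.values():
        for class_name, (save_name, _load_name) in library_classes.items():
            # match on the class itself or any of its bases
            if any(base.__name__ == class_name for base in type(module).__mro__):
                return getattr(module, save_name)
    raise ValueError(f"no save method registered for {type(module).__name__}")
```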
src/diffusers/samplers/gaussian.py (deleted, 100644 → 0)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn

from inspect import isfunction
from tqdm import tqdm

from ..configuration_utils import Config


SAMPLING_CONFIG_NAME = "sampler_config.json"


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cycle(dl):
    while True:
        for data_dl in dl:
            yield data_dl


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


def normalize_to_neg_one_to_one(img):
    return img * 2 - 1


def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5


# small helper modules


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


# gaussian diffusion trainer class


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    def repeat_noise():
        return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

    def noise():
        return torch.randn(shape, device=device)

    return repeat_noise() if repeat else noise()


def linear_beta_schedule(timesteps):
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)


class GaussianDiffusion(nn.Module, Config):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        image_size,
        channels=3,
        timesteps=1000,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="cosine",
    ):
        super().__init__()
        self.register(
            image_size=image_size,
            channels=channels,
            timesteps=timesteps,
            loss_type=loss_type,
            objective=objective,
            beta_schedule=beta_schedule,
        )
        self.channels = channels
        self.image_size = image_size
        self.objective = objective

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # helper function to register buffer from float64 to float32
        def register_buffer(name, val):
            self.register_buffer(name, val.to(torch.float32))

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer("posterior_variance", posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register_buffer(
            "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised: bool):
        model_output = model(x, t)

        if self.objective == "pred_noise":
            x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
        elif self.objective == "pred_x0":
            x_start = model_output
        else:
            raise ValueError(f"unknown objective {self.objective}")

        if clip_denoised:
            x_start.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, model, x, t, noise=None, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
        if noise is None:
            noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        return result

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        device = self.betas.device

        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, model, batch_size=16):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))

    @torch.no_grad()
    def interpolate(self, model, x1, x2, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @property
    def loss_fn(self):
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")

    def p_losses(self, model, x_start, t, noise=None):
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = model(x, t)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_start
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target)
        return loss

    def forward(self, model, img, *args, **kwargs):
        b, _, h, w, device, img_size, = (
            *img.shape,
            img.device,
            self.image_size,
        )
        assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(model, img, t, *args, **kwargs)
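The buffers registered in `GaussianDiffusion.__init__` above are the closed-form coefficients of the forward-process posterior $q(x_{t-1} \mid x_t, x_0)$ from the DDPM paper (Eq. 6 and 7); this is a restatement of what the code computes, not new material:

$$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\,x_t, \qquad \tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$$

with `posterior_mean_coef1` and `posterior_mean_coef2` the two mean coefficients and `posterior_variance` equal to $\tilde\beta_t$, log-clamped at $10^{-20}$ because $\tilde\beta_0 = 0$.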
src/diffusers/samplers/__init__.py → src/diffusers/schedulers/__init__.py

@@ -16,4 +16,4 @@

# See the License for the specific language governing permissions and
# limitations under the License.

-from .gaussian import GaussianDiffusion
+from .gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/schedulers/gaussian_ddpm.py (new file, 0 → 100644)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn

from ..configuration_utils import Config


SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


class GaussianDDPMScheduler(nn.Module, Config):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            variance_type=variance_type,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        if variance_type == "fixed_small":
            log_variance = torch.log(variance.clamp(min=1e-20))
        elif variance_type == "fixed_large":
            log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))

        self.register_buffer("betas", betas.to(torch.float32))
        self.register_buffer("alphas", alphas.to(torch.float32))
        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
        self.register_buffer("log_variance", log_variance.to(torch.float32))

    def get_alpha(self, time_step):
        return self.alphas[time_step]

    def get_beta(self, time_step):
        return self.betas[time_step]

    def get_alpha_prod(self, time_step):
        if time_step < 0:
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

    def sample_variance(self, time_step, shape, device, generator=None):
        variance = self.log_variance[time_step]
        nonzero_mask = torch.tensor([1 - (time_step == 0)], device=device).float()[None, :].repeat(shape[0], 1)

        noise = self.sample_noise(shape, device=device, generator=generator)

        sampled_variance = nonzero_mask * (0.5 * variance).exp()
        sampled_variance = sampled_variance * noise

        return sampled_variance

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
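A minimal usage sketch of the new scheduler, condensed from the denoising loop that tests/test_modeling_utils.py below runs against the `fusing/ddpm-lsun-church` checkpoint. The helper names and shapes are exactly those of the class above; the loop body is abbreviated to the mean update plus sampled variance:

```python
import torch
from diffusers import GaussianDDPMScheduler, UNetModel

scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church")

# start from pure noise and denoise step by step
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device="cpu")
for t in reversed(range(len(scheduler))):
    with torch.no_grad():
        noise_residual = model(image, t)

    # reconstruct and clip x_0, then form the posterior mean of x_{t-1}
    pred_original = torch.clamp(
        image / torch.sqrt(scheduler.get_alpha_prod(t))
        - torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) * noise_residual,
        -1, 1,
    )
    coef1 = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
    coef2 = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
    image = coef1 * pred_original + coef2 * image + scheduler.sample_variance(t, image.shape, device="cpu")
```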
tests/test_modeling_utils.py

@@ -16,13 +16,45 @@

import random
import tempfile
import unittest
+import os
+from distutils.util import strtobool

import torch

-from diffusers import GaussianDiffusion, UNetModel
+from diffusers import GaussianDDPMScheduler, UNetModel


global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"


+def parse_flag_from_env(key, default=False):
+    try:
+        value = os.environ[key]
+    except KeyError:
+        # KEY isn't set, default to `default`.
+        _value = default
+    else:
+        # KEY is set, convert it to True or False.
+        try:
+            _value = strtobool(value)
+        except ValueError:
+            # More values are supported, but let's keep the message simple.
+            raise ValueError(f"If set, {key} must be yes or no.")
+    return _value


+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)


+def slow(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+    """
+    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


def floats_tensor(shape, scale=1.0, rng=None, name=None):

@@ -54,7 +86,7 @@ class ModelTesterMixin(unittest.TestCase):

        return (noise, time_step)

    def test_from_pretrained_save_pretrained(self):
-        model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
+        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

@@ -77,30 +109,93 @@ class ModelTesterMixin(unittest.TestCase):

class SamplerTesterMixin(unittest.TestCase):
    @property
    def dummy_model(self):
        return UNetModel.from_pretrained("fusing/ddpm_dummy")

-    def test_from_pretrained_save_pretrained(self):
-        sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            sampler.save_config(tmpdirname)
-            new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)
-
-        model = self.dummy_model
-
-        torch.manual_seed(0)
-        sampled_out = sampler.sample(model, batch_size=1)
+    @slow
+    def test_sample(self):
+        generator = torch.Generator()
+        generator = generator.manual_seed(6694729458485568)
+
+        # 1. Load models
+        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
+        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
+
+        # 2. Sample gaussian noise
+        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
+
+        # 3. Denoise
+        for t in reversed(range(len(scheduler))):
+            # i) define coefficients for time step t
+            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
+            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
+            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+
+            # ii) predict noise residual
+            with torch.no_grad():
+                noise_residual = model(image, t)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clip_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
+        # Note: The better test is to simply check with the following lines of code that the image is sensible
+        # import PIL
+        # import numpy as np
+        # image_processed = image.cpu().permute(0, 2, 3, 1)
+        # image_processed = (image_processed + 1.0) * 127.5
+        # image_processed = image_processed.numpy().astype(np.uint8)
+        # image_pil = PIL.Image.fromarray(image_processed[0])
+        # image_pil.save("test.png")
+
+        assert image.shape == (1, 3, 256, 256)
+        image_slice = image[0, -1, -3:, -3:].cpu()
+        assert (image_slice - torch.tensor([[-0.0598, -0.0611, -0.0506], [-0.0726, 0.0220, 0.0103], [-0.0723, -0.1310, -0.2458]])).abs().sum() < 1e-3

+    def test_sample_fast(self):
+        # 1. Load models
+        generator = torch.Generator()
+        generator = generator.manual_seed(6694729458485568)
+
+        scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
+        model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
+
+        # 2. Sample gaussian noise
-        torch.manual_seed(0)
-        sampled_out_new = new_sampler.sample(model, batch_size=1)
-
-        assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"
-
-    def test_from_pretrained_hub(self):
-        sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
-
-        model = self.dummy_model
-        sampled_out = sampler.sample(model, batch_size=1)
-
-        assert sampled_out is not None, "Make sure output is not None"
+        image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator)
+
+        # 3. Denoise
+        for t in reversed(range(len(scheduler))):
+            # i) define coefficients for time step t
+            clip_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
+            clip_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
+            image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t))
+            clip_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
+
+            # ii) predict noise residual
+            with torch.no_grad():
+                noise_residual = model(image, t)
+
+            # iii) compute predicted image from residual
+            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
+            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
+            pred_mean = torch.clamp(pred_mean, -1, 1)
+            prev_image = clip_coeff * pred_mean + image_coeff * image
+
+            # iv) sample variance
+            prev_variance = scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
+
+            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
+            sampled_prev_image = prev_image + prev_variance
+            image = sampled_prev_image
+
+        assert image.shape == (1, 3, 256, 256)
+        image_slice = image[0, -1, -3:, -3:].cpu()
+        assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
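To exercise the `@slow`-marked `test_sample`, the decorator above expects a truthy `RUN_SLOW` environment variable, e.g. `RUN_SLOW=yes python -m pytest tests/test_modeling_utils.py` (the exact invocation is illustrative; `strtobool` accepts the usual yes/no, true/false, 1/0 spellings).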