renzhc / diffusers_dcu

Commit 1e21f061, authored Jun 08, 2022 by anton-l
parent d1715d33

Classifier-free guidance scheduler + GLIDE pipeline
Showing 11 changed files with 277 additions and 50 deletions (+277 -50).
models/vision/glide/README.md                          +4   -0
models/vision/glide/convert_weights.py                 +37  -9
models/vision/glide/modeling_glide.py                  +119 -29
models/vision/glide/run_glide.py                       +2   -7
src/diffusers/__init__.py                              +2   -0
src/diffusers/models/__init__.py                       +1   -0
src/diffusers/models/clip_text_transformer.py          +0   -0
src/diffusers/models/unet_glide.py                     +6   -4
src/diffusers/pipeline_utils.py                        +3   -1
src/diffusers/schedulers/__init__.py                   +1   -0
src/diffusers/schedulers/classifier_free_guidance.py   +102 -0
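Background for the commit title: classifier-free guidance steers a diffusion sampler toward a text prompt by running the denoiser on both the conditional input and an unconditional (empty-prompt) input and extrapolating between the two predictions; this is why the pipeline below always tokenizes the pair `[prompt, ""]`. A minimal sketch of the standard combination rule from the GLIDE paper follows; `eps_cond`, `eps_uncond`, and `guidance_scale` are illustrative names, not identifiers from this commit.

import torch

def guided_epsilon(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guidance_scale: float = 3.0) -> torch.Tensor:
    # eps_hat = eps_uncond + s * (eps_cond - eps_uncond); s > 1 pushes samples toward the prompt
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)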
models/vision/glide/README.md

# References

[GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models](https://arxiv.org/pdf/2112.10741.pdf)

[Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/pdf/2105.05233.pdf)
models/vision/glide/convert_weights.py

import argparse

import torch
from torch import nn
from transformers import CLIPTextConfig, GPT2Tokenizer

from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from modeling_glide import GLIDE

# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict = torch.load("base.pt", map_location="cpu")
state_dict = {k: nn.Parameter(v) for k, v in state_dict.items()}

### Convert the text encoder

config = CLIPTextConfig(
    vocab_size=50257,
    max_position_embeddings=128,
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=16,
    num_attention_heads=8,
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
# tokenizer.save_pretrained("./glide-base")

hf_encoder = model.text_model

...

@@ -48,8 +51,33 @@ for layer_idx in range(config.num_hidden_layers):

    hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
    hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
# with torch.no_grad():
#     outputs = model(**inputs)
# model.save_pretrained("./glide-base")

### Convert the UNet

unet_model = UNetGLIDEModel(
    in_channels=3,
    model_channels=192,
    out_channels=6,
    num_res_blocks=3,
    attention_resolutions=(2, 4, 8),
    dropout=0.1,
    channel_mult=(1, 2, 3, 4),
    num_heads=1,
    num_head_channels=64,
    num_heads_upsample=1,
    use_scale_shift_norm=True,
    resblock_updown=True,
)
unet_model.load_state_dict(state_dict, strict=False)

scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
glide.save_pretrained("./glide-base")
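As a usage sketch (an editor's illustration, not part of this commit): the directory written by this script should be loadable back through the pipeline machinery registered in pipeline_utils.py below, mirroring run_glide.py. The prompt string reuses the example from this file; treat the exact round trip as an assumption.

from modeling_glide import GLIDE

# reload the converted pipeline (UNet, scheduler, text encoder, tokenizer) in one call
pipeline = GLIDE.from_pretrained("./glide-base")
img = pipeline("an oil painting of a corgi")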
models/vision/glide/modeling_glide.py

@@ -14,46 +14,136 @@

# limitations under the License.

import numpy as np
import torch
import tqdm

from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
from transformers import GPT2Tokenizer


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)


class GLIDE(DiffusionPipeline):
    def __init__(
        self,
        unet: UNetGLIDEModel,
        noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
    ):
        super().__init__()
        self.register_modules(
            unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.noise_scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.noise_scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.noise_scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.noise_scheduler.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, transformer_out, clip_denoised=True, model_kwargs=None):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: the tuple (model_mean, model_variance, model_log_variance, pred_xstart),
                 where pred_xstart is the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, transformer_out)

        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = torch.split(model_output, C, dim=1)
        min_log = _extract_into_tensor(self.noise_scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(self.noise_scheduler.betas), t, x.shape)
        # The model_var_values is [-1, 1] for [min_var, max_var].
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)

        pred_xstart = self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.noise_scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.noise_scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def __call__(self, prompt, generator=None, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        self.text_encoder.to(torch_device)

        # 1. Sample gaussian noise
        image = self.noise_scheduler.sample_noise(
            (1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        # 2. Encode tokens
        # an empty input is needed to guide the model away from (
        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        transformer_out = self.text_encoder(**inputs).last_hidden_state

        num_timesteps = len(self.noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            with torch.no_grad():
                t = torch.tensor([i] * image.shape[0], device=torch_device)
                # predict the mean/variance of p(x_{t-1} | x_t) under the text conditioning
                mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                    self.unet, image, t, transformer_out
                )
                noise = self.noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
                # no noise when t == 0
                # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
                nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
                image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        return image
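To make `_extract_into_tensor` concrete (an editor's illustration, not commit code): it picks one scalar per batch element out of a 1-D schedule array and broadcasts it against an image-shaped tensor, which is exactly how the posterior coefficients above are applied per timestep.

import numpy as np
import torch

arr = np.linspace(0.0, 1.0, 1000)          # stand-in for e.g. posterior_mean_coef1, 1000 timesteps
timesteps = torch.tensor([0, 499, 999])    # one timestep index per batch element
broadcast_shape = (3, 3, 64, 64)           # batch of 3 RGB images

res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
    res = res[..., None]                   # [3] -> [3, 1, 1, 1]
out = res + torch.zeros(broadcast_shape)   # broadcast up to [3, 3, 64, 64]
assert out.shape == broadcast_shape
# each out[i] is filled with the single value arr[timesteps[i]]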
models/vision/glide/run_glide.py

import torch

from modeling_glide import GLIDE

generator = torch.Generator()
generator = generator.manual_seed(0)

# 1. Load models
pipeline = GLIDE.from_pretrained("fusing/glide-base")

img = pipeline("an oil painting of a corgi", generator)

...
src/diffusers/__init__.py

@@ -7,5 +7,7 @@ __version__ = "0.0.1"

from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import UNetGLIDEModel
from .models.clip_text_transformer import CLIPTextModel
from .pipeline_utils import DiffusionPipeline
from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
src/diffusers/models/__init__.py

@@ -18,3 +18,4 @@

from .unet import UNetModel
from .unet_glide import UNetGLIDEModel
from .clip_text_transformer import CLIPTextModel
models/vision/glide/modelling_text_encoder.py → src/diffusers/models/clip_text_transformer.py

File moved
src/diffusers/models/unet_glide.py

@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):

        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        # self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

...

@@ -653,13 +653,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):

        transformer_proj = self.transformer_proj(transformer_out[:, -1])
        transformer_out = transformer_out.permute(0, 2, 1)  # NLC -> NCL

        emb = emb + transformer_proj.to(emb)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, transformer_out)
            hs.append(h)
        h = self.middle_block(h, emb, transformer_out)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, transformer_out)
        h = h.type(x.dtype)
        return self.out(h)
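The forward pass above now hands `transformer_out` to every input, middle, and output block. A plausible dispatch pattern, sketched by the editor as an assumption rather than this commit's code (GLIDE-style UNets typically wrap blocks in a sequential container that forwards extra conditioning only to layers that accept it):

from torch import nn

class ConditionedSequential(nn.Sequential):
    """Pass (emb, transformer_out) only to layers flagged as conditioned; others get h alone."""

    def forward(self, h, emb, transformer_out=None):
        for layer in self:
            if getattr(layer, "is_conditioned", False):   # hypothetical flag on resblocks/attention layers
                h = layer(h, emb, transformer_out)
            else:
                h = layer(h)
        return h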
src/diffusers/pipeline_utils.py

@@ -35,10 +35,12 @@ logger = logging.get_logger(__name__)

LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "CLIPTextModel": ["save_pretrained", "from_pretrained"],  # TODO (Anton): move to transformers
        "GaussianDDPMScheduler": ["save_config", "from_config"],
        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
    },
    "transformers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "GPT2Tokenizer": ["save_pretrained", "from_pretrained"],
    },
}

...
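How a registry like `LOADABLE_CLASSES` can be consumed (a rough sketch under the assumption that each pipeline keeps a name → module mapping from `register_modules`; this is not the literal `DiffusionPipeline.save_pretrained` body):

import os

def save_components(modules, save_directory, loadable_classes=LOADABLE_CLASSES):
    # modules: assumed dict like {"unet": unet, "noise_scheduler": scheduler, "tokenizer": tokenizer}
    for name, module in modules.items():
        for library_classes in loadable_classes.values():
            methods = library_classes.get(module.__class__.__name__)
            if methods is not None:
                save_method_name, _load_method_name = methods  # e.g. ("save_pretrained", "from_pretrained")
                getattr(module, save_method_name)(os.path.join(save_directory, name))
                break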
src/diffusers/schedulers/__init__.py

@@ -17,3 +17,4 @@

# limitations under the License.
from .gaussian_ddpm import GaussianDDPMScheduler
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
src/diffusers/schedulers/classifier_free_guidance.py (new file, mode 100644)

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch
from torch import nn

from ..configuration_utils import ConfigMixin


SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)


class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_schedule=beta_schedule,
        )
        self.num_timesteps = int(timesteps)

        if beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
            betas = betas_for_alpha_bar(
                timesteps,
                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # stored as an attribute so callers (e.g. GLIDE.p_mean_variance) can read noise_scheduler.betas
        self.betas = betas
        alphas = 1.0 - betas

        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.num_timesteps
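A quick sanity check of the new scheduler (an editor's usage example, not part of the commit); it exercises only names defined in the file above and verifies two properties the posterior derivation relies on:

import numpy as np
import torch
from diffusers import ClassifierFreeGuidanceScheduler

scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")
assert len(scheduler) == 1000
assert np.all(np.diff(scheduler.alphas_cumprod) < 0)   # noise accumulates: alpha_bar strictly decreases
assert np.all(scheduler.posterior_variance >= 0)       # q(x_{t-1} | x_t, x_0) has nonnegative variance

generator = torch.Generator().manual_seed(0)
noise = scheduler.sample_noise((1, 3, 64, 64), device="cpu", generator=generator)
assert noise.shape == (1, 3, 64, 64)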