Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
9fdbc14e
Commit
9fdbc14e
authored
Jun 08, 2022
by
Patrick von Platen
Browse files
Merge branch 'main' of
https://github.com/huggingface/diffusers
into main
parents
ef044a72
d754ce5f
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
998 additions
and
84 deletions
+998
-84
models/vision/glide/README.md
models/vision/glide/README.md
+4
-0
models/vision/glide/convert_weights.py
models/vision/glide/convert_weights.py
+37
-19
models/vision/glide/modeling_glide.py
models/vision/glide/modeling_glide.py
+138
-30
models/vision/glide/run_glide.py
models/vision/glide/run_glide.py
+5
-8
src/diffusers/__init__.py
src/diffusers/__init__.py
+3
-0
src/diffusers/configuration_utils.py
src/diffusers/configuration_utils.py
+3
-7
src/diffusers/models/__init__.py
src/diffusers/models/__init__.py
+2
-0
src/diffusers/models/clip_text_transformer.py
src/diffusers/models/clip_text_transformer.py
+685
-0
src/diffusers/models/unet_glide.py
src/diffusers/models/unet_glide.py
+16
-16
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+2
-2
src/diffusers/pipeline_utils.py
src/diffusers/pipeline_utils.py
+5
-2
src/diffusers/schedulers/__init__.py
src/diffusers/schedulers/__init__.py
+1
-0
src/diffusers/schedulers/classifier_free_guidance.py
src/diffusers/schedulers/classifier_free_guidance.py
+97
-0
No files found.
models/vision/glide/README.md
View file @
9fdbc14e
# References
[
GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models
](
https://arxiv.org/pdf/2112.10741.pdf
)
[
Diffusion Models Beat GANs on Image Synthesis
](
https://arxiv.org/pdf/2105.05233.pdf
)
\ No newline at end of file
models/vision/glide/convert_weights.py
View file @
9fdbc14e
import
argparse
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
CLIPTextConfig
,
CLIPTextModel
,
GPT2Tokenizer
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
UNetGLIDEModel
from
modeling_glide
import
GLIDE
from
transformers
import
CLIPTextConfig
,
GPT2Tokenizer
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
state_dict
=
torch
.
load
(
"base.pt"
,
map_location
=
"cpu"
)
state_dict
=
torch
.
load
(
"base.pt"
,
map_location
=
"cpu"
)
state_dict
=
{
k
:
nn
.
Parameter
(
v
)
for
k
,
v
in
state_dict
.
items
()}
state_dict
=
{
k
:
nn
.
Parameter
(
v
)
for
k
,
v
in
state_dict
.
items
()}
### Convert the text encoder
config
=
CLIPTextConfig
(
config
=
CLIPTextConfig
(
vocab_size
=
50257
,
max_position_embeddings
=
128
,
hidden_size
=
512
,
hidden_size
=
512
,
intermediate_size
=
2048
,
intermediate_size
=
2048
,
num_hidden_layers
=
16
,
num_hidden_layers
=
16
,
num_attention_heads
=
8
,
num_attention_heads
=
8
,
max_position
_embeddings
=
128
use_padding
_embeddings
=
True
,
)
)
model
=
CLIPTextModel
(
config
).
eval
()
model
=
CLIPTextModel
(
config
).
eval
()
tokenizer
=
GPT2Tokenizer
(
"./glide-base/vocab.json"
,
"./glide-base/merges.txt"
,
pad_token
=
"<|endoftext|>"
)
tokenizer
=
GPT2Tokenizer
(
"./glide-base/tokenizer/vocab.json"
,
"./glide-base/tokenizer/merges.txt"
,
pad_token
=
"<|endoftext|>"
)
tokenizer
.
save_pretrained
(
"./glide-base"
)
hf_encoder
=
model
.
text_model
hf_encoder
=
model
.
text_model
...
@@ -30,15 +35,8 @@ hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
...
@@ -30,15 +35,8 @@ hf_encoder.final_layer_norm.bias = state_dict["final_ln.bias"]
for
layer_idx
in
range
(
config
.
num_hidden_layers
):
for
layer_idx
in
range
(
config
.
num_hidden_layers
):
hf_layer
=
hf_encoder
.
encoder
.
layers
[
layer_idx
]
hf_layer
=
hf_encoder
.
encoder
.
layers
[
layer_idx
]
q_proj
,
k_proj
,
v_proj
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_qkv.weight"
].
chunk
(
3
,
dim
=
0
)
hf_layer
.
self_attn
.
qkv_proj
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_qkv.weight"
]
q_proj_bias
,
k_proj_bias
,
v_proj_bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_qkv.bias"
].
chunk
(
3
,
dim
=
0
)
hf_layer
.
self_attn
.
qkv_proj
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_qkv.bias"
]
hf_layer
.
self_attn
.
q_proj
.
weight
.
data
=
q_proj
hf_layer
.
self_attn
.
q_proj
.
bias
.
data
=
q_proj_bias
hf_layer
.
self_attn
.
k_proj
.
weight
.
data
=
k_proj
hf_layer
.
self_attn
.
k_proj
.
bias
.
data
=
k_proj_bias
hf_layer
.
self_attn
.
v_proj
.
weight
.
data
=
v_proj
hf_layer
.
self_attn
.
v_proj
.
bias
.
data
=
v_proj_bias
hf_layer
.
self_attn
.
out_proj
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_proj.weight"
]
hf_layer
.
self_attn
.
out_proj
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_proj.weight"
]
hf_layer
.
self_attn
.
out_proj
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_proj.bias"
]
hf_layer
.
self_attn
.
out_proj
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.attn.c_proj.bias"
]
...
@@ -53,8 +51,28 @@ for layer_idx in range(config.num_hidden_layers):
...
@@ -53,8 +51,28 @@ for layer_idx in range(config.num_hidden_layers):
hf_layer
.
mlp
.
fc2
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.weight"
]
hf_layer
.
mlp
.
fc2
.
weight
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.weight"
]
hf_layer
.
mlp
.
fc2
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.bias"
]
hf_layer
.
mlp
.
fc2
.
bias
=
state_dict
[
f
"transformer.resblocks.
{
layer_idx
}
.mlp.c_proj.bias"
]
inputs
=
tokenizer
([
"an oil painting of a corgi"
,
""
],
padding
=
"max_length"
,
max_length
=
128
,
return_tensors
=
"pt"
)
### Convert the UNet
with
torch
.
no_grad
():
outputs
=
model
(
**
inputs
)
unet_model
=
UNetGLIDEModel
(
in_channels
=
3
,
model_channels
=
192
,
out_channels
=
6
,
num_res_blocks
=
3
,
attention_resolutions
=
(
2
,
4
,
8
),
dropout
=
0.1
,
channel_mult
=
(
1
,
2
,
3
,
4
),
num_heads
=
1
,
num_head_channels
=
64
,
num_heads_upsample
=
1
,
use_scale_shift_norm
=
True
,
resblock_updown
=
True
,
transformer_dim
=
512
,
)
unet_model
.
load_state_dict
(
state_dict
,
strict
=
False
)
scheduler
=
ClassifierFreeGuidanceScheduler
(
timesteps
=
1000
,
beta_schedule
=
"squaredcos_cap_v2"
)
glide
=
GLIDE
(
unet
=
unet_model
,
noise_scheduler
=
scheduler
,
text_encoder
=
model
,
tokenizer
=
tokenizer
)
model
.
save_pretrained
(
"./glide-base"
)
glide
.
save_pretrained
(
"./glide-base"
)
\ No newline at end of file
models/vision/glide/modeling_glide.py
View file @
9fdbc14e
...
@@ -14,46 +14,154 @@
...
@@ -14,46 +14,154 @@
# limitations under the License.
# limitations under the License.
from
diffusers
import
DiffusionPipeline
import
numpy
as
np
from
diffusers
import
UNetGLIDEModel
import
torch
import
tqdm
import
tqdm
import
torch
from
diffusers
import
ClassifierFreeGuidanceScheduler
,
CLIPTextModel
,
DiffusionPipeline
,
UNetGLIDEModel
from
transformers
import
GPT2Tokenizer
def
_extract_into_tensor
(
arr
,
timesteps
,
broadcast_shape
):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res
=
torch
.
from_numpy
(
arr
).
to
(
device
=
timesteps
.
device
)[
timesteps
].
float
()
while
len
(
res
.
shape
)
<
len
(
broadcast_shape
):
res
=
res
[...,
None
]
return
res
+
torch
.
zeros
(
broadcast_shape
,
device
=
timesteps
.
device
)
class
GLIDE
(
DiffusionPipeline
):
class
GLIDE
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
:
UNetGLIDEModel
,
noise_scheduler
):
def
__init__
(
self
,
unet
:
UNetGLIDEModel
,
noise_scheduler
:
ClassifierFreeGuidanceScheduler
,
text_encoder
:
CLIPTextModel
,
tokenizer
:
GPT2Tokenizer
,
):
super
().
__init__
()
super
().
__init__
()
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
,
text_encoder
=
text_encoder
,
tokenizer
=
tokenizer
)
def
q_posterior_mean_variance
(
self
,
x_start
,
x_t
,
t
):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert
x_start
.
shape
==
x_t
.
shape
posterior_mean
=
(
_extract_into_tensor
(
self
.
noise_scheduler
.
posterior_mean_coef1
,
t
,
x_t
.
shape
)
*
x_start
+
_extract_into_tensor
(
self
.
noise_scheduler
.
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
)
posterior_variance
=
_extract_into_tensor
(
self
.
noise_scheduler
.
posterior_variance
,
t
,
x_t
.
shape
)
posterior_log_variance_clipped
=
_extract_into_tensor
(
self
.
noise_scheduler
.
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
assert
(
posterior_mean
.
shape
[
0
]
==
posterior_variance
.
shape
[
0
]
==
posterior_log_variance_clipped
.
shape
[
0
]
==
x_start
.
shape
[
0
]
)
return
posterior_mean
,
posterior_variance
,
posterior_log_variance_clipped
def
p_mean_variance
(
self
,
model
,
x
,
t
,
transformer_out
,
clip_denoised
=
True
,
model_kwargs
=
None
):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
:param t: a 1-D Tensor of timesteps.
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict with the following keys:
- 'mean': the model mean output.
- 'variance': the model variance output.
- 'log_variance': the log of 'variance'.
- 'pred_xstart': the prediction for x_0.
"""
if
model_kwargs
is
None
:
model_kwargs
=
{}
def
__call__
(
self
,
generator
=
None
,
torch_device
=
None
):
B
,
C
=
x
.
shape
[:
2
]
assert
t
.
shape
==
(
B
,)
model_output
=
model
(
x
,
t
,
transformer_out
)
assert
model_output
.
shape
==
(
B
,
C
*
2
,
*
x
.
shape
[
2
:])
model_output
,
model_var_values
=
torch
.
split
(
model_output
,
C
,
dim
=
1
)
min_log
=
_extract_into_tensor
(
self
.
noise_scheduler
.
posterior_log_variance_clipped
,
t
,
x
.
shape
)
max_log
=
_extract_into_tensor
(
np
.
log
(
self
.
noise_scheduler
.
betas
),
t
,
x
.
shape
)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac
=
(
model_var_values
+
1
)
/
2
model_log_variance
=
frac
*
max_log
+
(
1
-
frac
)
*
min_log
model_variance
=
torch
.
exp
(
model_log_variance
)
pred_xstart
=
self
.
_predict_xstart_from_eps
(
x_t
=
x
,
t
=
t
,
eps
=
model_output
)
if
clip_denoised
:
pred_xstart
=
pred_xstart
.
clamp
(
-
1
,
1
)
model_mean
,
_
,
_
=
self
.
q_posterior_mean_variance
(
x_start
=
pred_xstart
,
x_t
=
x
,
t
=
t
)
assert
model_mean
.
shape
==
model_log_variance
.
shape
==
pred_xstart
.
shape
==
x
.
shape
return
model_mean
,
model_variance
,
model_log_variance
,
pred_xstart
def
_predict_xstart_from_eps
(
self
,
x_t
,
t
,
eps
):
assert
x_t
.
shape
==
eps
.
shape
return
(
_extract_into_tensor
(
self
.
noise_scheduler
.
sqrt_recip_alphas_cumprod
,
t
,
x_t
.
shape
)
*
x_t
-
_extract_into_tensor
(
self
.
noise_scheduler
.
sqrt_recipm1_alphas_cumprod
,
t
,
x_t
.
shape
)
*
eps
)
def
__call__
(
self
,
prompt
,
generator
=
None
,
torch_device
=
None
):
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
self
.
unet
.
to
(
torch_device
)
self
.
unet
.
to
(
torch_device
)
self
.
text_encoder
.
to
(
torch_device
)
# Create a classifier-free guidance sampling function
guidance_scale
=
3.0
def
model_fn
(
x_t
,
ts
,
transformer_out
,
**
kwargs
):
half
=
x_t
[:
len
(
x_t
)
//
2
]
combined
=
torch
.
cat
([
half
,
half
],
dim
=
0
)
model_out
=
self
.
unet
(
combined
,
ts
,
transformer_out
,
**
kwargs
)
eps
,
rest
=
model_out
[:,
:
3
],
model_out
[:,
3
:]
cond_eps
,
uncond_eps
=
torch
.
split
(
eps
,
len
(
eps
)
//
2
,
dim
=
0
)
half_eps
=
uncond_eps
+
guidance_scale
*
(
cond_eps
-
uncond_eps
)
eps
=
torch
.
cat
([
half_eps
,
half_eps
],
dim
=
0
)
return
torch
.
cat
([
eps
,
rest
],
dim
=
1
)
# 1. Sample gaussian noise
# 1. Sample gaussian noise
image
=
self
.
noise_scheduler
.
sample_noise
((
1
,
self
.
unet
.
in_channels
,
self
.
unet
.
resolution
,
self
.
unet
.
resolution
),
device
=
torch_device
,
generator
=
generator
)
batch_size
=
2
# second image is empty for classifier-free guidance
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
len
(
self
.
noise_scheduler
))),
total
=
len
(
self
.
noise_scheduler
)):
image
=
self
.
noise_scheduler
.
sample_noise
(
# i) define coefficients for time step t
(
batch_size
,
self
.
unet
.
in_channels
,
64
,
64
),
device
=
torch_device
,
generator
=
generator
clip_image_coeff
=
1
/
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
)
clip_noise_coeff
=
torch
.
sqrt
(
1
/
self
.
noise_scheduler
.
get_alpha_prod
(
t
)
-
1
)
image_coeff
=
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha
(
t
))
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
# 2. Encode tokens
clip_coeff
=
torch
.
sqrt
(
self
.
noise_scheduler
.
get_alpha_prod
(
t
-
1
))
*
self
.
noise_scheduler
.
get_beta
(
t
)
/
(
1
-
self
.
noise_scheduler
.
get_alpha_prod
(
t
))
# an empty input is needed to guide the model away from (
inputs
=
self
.
tokenizer
([
prompt
,
""
],
padding
=
"max_length"
,
max_length
=
128
,
return_tensors
=
"pt"
)
# ii) predict noise residual
input_ids
=
inputs
[
"input_ids"
].
to
(
torch_device
)
with
torch
.
no_grad
():
attention_mask
=
inputs
[
"attention_mask"
].
to
(
torch_device
)
noise_residual
=
self
.
unet
(
image
,
t
)
transformer_out
=
self
.
text_encoder
(
input_ids
,
attention_mask
).
last_hidden_state
# iii) compute predicted image from residual
num_timesteps
=
len
(
self
.
noise_scheduler
)
# See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
for
i
in
tqdm
.
tqdm
(
reversed
(
range
(
num_timesteps
)),
total
=
num_timesteps
):
pred_mean
=
clip_image_coeff
*
image
-
clip_noise_coeff
*
noise_residual
t
=
torch
.
tensor
([
i
]
*
image
.
shape
[
0
],
device
=
torch_device
)
pred_mean
=
torch
.
clamp
(
pred_mean
,
-
1
,
1
)
mean
,
variance
,
log_variance
,
pred_xstart
=
self
.
p_mean_variance
(
model_fn
,
image
,
t
,
transformer_out
)
prev_image
=
clip_coeff
*
pred_mean
+
image_coeff
*
image
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
torch_device
,
generator
=
generator
)
nonzero_mask
=
(
t
!=
0
).
float
().
view
(
-
1
,
*
([
1
]
*
(
len
(
image
.
shape
)
-
1
)))
# no noise when t == 0
# iv) sample variance
image
=
mean
+
nonzero_mask
*
torch
.
exp
(
0.5
*
log_variance
)
*
noise
prev_variance
=
self
.
noise_scheduler
.
sample_variance
(
t
,
prev_image
.
shape
,
device
=
torch_device
,
generator
=
generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image
=
prev_image
+
prev_variance
image
=
sampled_prev_image
return
image
return
image
models/vision/glide/run_glide.py
View file @
9fdbc14e
import
torch
import
torch
from
.modeling_glide
import
GLIDE
from
diffusers
import
UNetGLIDEModel
,
GaussianDDPMScheduler
from
modeling_glide
import
GLIDE
generator
=
torch
.
Generator
()
generator
=
torch
.
Generator
()
generator
=
generator
.
manual_seed
(
0
)
generator
=
generator
.
manual_seed
(
0
)
# 1. Load models
# 1. Load models
pipeline
=
GLIDE
.
from_pretrained
(
"fusing/glide-base"
)
scheduler
=
GaussianDDPMScheduler
.
from_config
(
"fusing/glide-base"
)
img
=
pipeline
(
"an oil painting of a corgi"
,
generator
)
model
=
UNetGLIDEModel
.
from_pretrained
(
"fusing/glide-base"
)
pipeline
=
GLIDE
(
model
,
scheduler
)
img
=
pipeline
(
generator
)
print
(
img
)
print
(
img
)
src/diffusers/__init__.py
View file @
9fdbc14e
...
@@ -5,7 +5,10 @@
...
@@ -5,7 +5,10 @@
__version__
=
"0.0.1"
__version__
=
"0.0.1"
from
.modeling_utils
import
ModelMixin
from
.modeling_utils
import
ModelMixin
from
.models.clip_text_transformer
import
CLIPTextModel
from
.models.unet
import
UNetModel
from
.models.unet
import
UNetModel
from
.models.unet_glide
import
UNetGLIDEModel
from
.models.unet_glide
import
UNetGLIDEModel
from
.models.unet_ldm
import
UNetLDMModel
from
.pipeline_utils
import
DiffusionPipeline
from
.pipeline_utils
import
DiffusionPipeline
from
.schedulers.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
from
.schedulers.gaussian_ddpm
import
GaussianDDPMScheduler
src/diffusers/configuration_utils.py
View file @
9fdbc14e
...
@@ -90,7 +90,6 @@ class ConfigMixin:
...
@@ -90,7 +90,6 @@ class ConfigMixin:
self
.
to_json_file
(
output_config_file
)
self
.
to_json_file
(
output_config_file
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
logger
.
info
(
f
"ConfigMixinuration saved in
{
output_config_file
}
"
)
@
classmethod
@
classmethod
def
get_config_dict
(
def
get_config_dict
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
**
kwargs
...
@@ -199,7 +198,6 @@ class ConfigMixin:
...
@@ -199,7 +198,6 @@ class ConfigMixin:
# use value from config dict
# use value from config dict
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
init_dict
[
key
]
=
config_dict
.
pop
(
key
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
unused_kwargs
=
config_dict
.
update
(
kwargs
)
passed_keys
=
set
(
init_dict
.
keys
())
passed_keys
=
set
(
init_dict
.
keys
())
...
@@ -212,9 +210,7 @@ class ConfigMixin:
...
@@ -212,9 +210,7 @@ class ConfigMixin:
@
classmethod
@
classmethod
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
def
from_config
(
cls
,
pretrained_model_name_or_path
:
Union
[
str
,
os
.
PathLike
],
return_unused_kwargs
=
False
,
**
kwargs
):
config_dict
=
cls
.
get_config_dict
(
config_dict
=
cls
.
get_config_dict
(
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
init_dict
,
unused_kwargs
=
cls
.
extract_init_dict
(
config_dict
,
**
kwargs
)
...
...
src/diffusers/models/__init__.py
View file @
9fdbc14e
...
@@ -16,5 +16,7 @@
...
@@ -16,5 +16,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.clip_text_transformer
import
CLIPTextModel
from
.unet
import
UNetModel
from
.unet
import
UNetModel
from
.unet_glide
import
UNetGLIDEModel
from
.unet_glide
import
UNetGLIDEModel
from
.unet_ldm
import
UNetLDMModel
src/diffusers/models/clip_text_transformer.py
0 → 100644
View file @
9fdbc14e
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CLIP model."""
import
math
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
,
Tuple
,
Union
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
transformers
import
CLIPConfig
,
CLIPModel
,
CLIPTextConfig
,
CLIPVisionConfig
from
transformers.activations
import
ACT2FN
from
transformers.modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
from
transformers.modeling_utils
import
PreTrainedModel
from
transformers.utils
import
(
ModelOutput
,
add_start_docstrings
,
add_start_docstrings_to_model_forward
,
logging
,
replace_return_docstrings
,
)
logger
=
logging
.
get_logger
(
__name__
)
_CHECKPOINT_FOR_DOC
=
"openai/clip-vit-base-patch32"
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
=
[
"openai/clip-vit-base-patch32"
,
# See all CLIP models at https://huggingface.co/models?filter=clip
]
# Copied from transformers.models.bart.modeling_bart._expand_mask
def
_expand_mask
(
mask
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
tgt_len
:
Optional
[
int
]
=
None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz
,
src_len
=
mask
.
size
()
tgt_len
=
tgt_len
if
tgt_len
is
not
None
else
src_len
expanded_mask
=
mask
[:,
None
,
None
,
:].
expand
(
bsz
,
1
,
tgt_len
,
src_len
).
to
(
dtype
)
inverted_mask
=
1.0
-
expanded_mask
return
inverted_mask
.
masked_fill
(
inverted_mask
.
to
(
torch
.
bool
),
torch
.
finfo
(
dtype
).
min
)
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def
contrastive_loss
(
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
nn
.
functional
.
cross_entropy
(
logits
,
torch
.
arange
(
len
(
logits
),
device
=
logits
.
device
))
def
clip_loss
(
similarity
:
torch
.
Tensor
)
->
torch
.
Tensor
:
caption_loss
=
contrastive_loss
(
similarity
)
image_loss
=
contrastive_loss
(
similarity
.
T
)
return
(
caption_loss
+
image_loss
)
/
2.0
@
dataclass
class
CLIPOutput
(
ModelOutput
):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
text_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPTextModel`].
vision_model_output(`BaseModelOutputWithPooling`):
The output of the [`CLIPVisionModel`].
"""
loss
:
Optional
[
torch
.
FloatTensor
]
=
None
logits_per_image
:
torch
.
FloatTensor
=
None
logits_per_text
:
torch
.
FloatTensor
=
None
text_embeds
:
torch
.
FloatTensor
=
None
image_embeds
:
torch
.
FloatTensor
=
None
text_model_output
:
BaseModelOutputWithPooling
=
None
vision_model_output
:
BaseModelOutputWithPooling
=
None
def
to_tuple
(
self
)
->
Tuple
[
Any
]:
return
tuple
(
self
[
k
]
if
k
not
in
[
"text_model_output"
,
"vision_model_output"
]
else
getattr
(
self
,
k
).
to_tuple
()
for
k
in
self
.
keys
()
)
class
CLIPVisionEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPVisionConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
class_embedding
=
nn
.
Parameter
(
torch
.
randn
(
self
.
embed_dim
))
self
.
patch_embedding
=
nn
.
Conv2d
(
in_channels
=
3
,
out_channels
=
self
.
embed_dim
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
,
bias
=
False
)
self
.
num_patches
=
(
self
.
image_size
//
self
.
patch_size
)
**
2
self
.
num_positions
=
self
.
num_patches
+
1
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_positions
,
self
.
embed_dim
)
self
.
register_buffer
(
"position_ids"
,
torch
.
arange
(
self
.
num_positions
).
expand
((
1
,
-
1
)))
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
)
->
torch
.
Tensor
:
batch_size
=
pixel_values
.
shape
[
0
]
patch_embeds
=
self
.
patch_embedding
(
pixel_values
)
# shape = [*, width, grid, grid]
patch_embeds
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
class_embeds
=
self
.
class_embedding
.
expand
(
batch_size
,
1
,
-
1
)
embeddings
=
torch
.
cat
([
class_embeds
,
patch_embeds
],
dim
=
1
)
embeddings
=
embeddings
+
self
.
position_embedding
(
self
.
position_ids
)
return
embeddings
class
CLIPTextEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
()
embed_dim
=
config
.
hidden_size
self
.
token_embedding
=
nn
.
Embedding
(
config
.
vocab_size
,
embed_dim
)
self
.
position_embedding
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
embed_dim
)
self
.
use_padding_embeddings
=
config
.
use_padding_embeddings
if
self
.
use_padding_embeddings
:
self
.
padding_embedding
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self
.
register_buffer
(
"position_ids"
,
torch
.
arange
(
config
.
max_position_embeddings
).
expand
((
1
,
-
1
)))
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
seq_length
=
input_ids
.
shape
[
-
1
]
if
input_ids
is
not
None
else
inputs_embeds
.
shape
[
-
2
]
if
position_ids
is
None
:
position_ids
=
self
.
position_ids
[:,
:
seq_length
]
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
token_embedding
(
input_ids
)
position_embeddings
=
self
.
position_embedding
(
position_ids
)
embeddings
=
inputs_embeds
+
position_embeddings
if
self
.
use_padding_embeddings
and
attention_mask
is
not
None
:
padding_embeddings
=
self
.
padding_embedding
(
position_ids
)
embeddings
=
torch
.
where
(
attention_mask
.
bool
().
unsqueeze
(
-
1
),
embeddings
,
padding_embeddings
)
return
embeddings
class
CLIPAttention
(
nn
.
Module
):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
if
self
.
head_dim
*
self
.
num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got `embed_dim`:
{
self
.
embed_dim
}
and `num_heads`:"
f
"
{
self
.
num_heads
}
)."
)
self
.
scale
=
1
/
math
.
sqrt
(
math
.
sqrt
(
self
.
head_dim
))
self
.
qkv_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
*
3
)
self
.
out_proj
=
nn
.
Linear
(
self
.
embed_dim
,
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
causal_attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
Optional
[
Tuple
[
torch
.
Tensor
]]]:
"""Input shape: Batch x Time x Channel"""
bsz
,
tgt_len
,
embed_dim
=
hidden_states
.
size
()
qkv_states
=
self
.
qkv_proj
(
hidden_states
)
qkv_states
=
qkv_states
.
view
(
bsz
,
tgt_len
,
self
.
num_heads
,
-
1
)
query_states
,
key_states
,
value_states
=
torch
.
split
(
qkv_states
,
self
.
head_dim
,
dim
=-
1
)
attn_weights
=
torch
.
einsum
(
"bthc,bshc->bhts"
,
query_states
*
self
.
scale
,
key_states
*
self
.
scale
)
wdtype
=
attn_weights
.
dtype
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
.
float
(),
dim
=-
1
).
type
(
wdtype
)
attn_output
=
torch
.
einsum
(
"bhts,bshc->bthc"
,
attn_weights
,
value_states
)
attn_output
=
attn_output
.
reshape
(
bsz
,
tgt_len
,
-
1
)
attn_output
=
self
.
out_proj
(
attn_output
)
return
attn_output
,
attn_weights
class
CLIPMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
self
.
fc2
=
nn
.
Linear
(
config
.
intermediate_size
,
config
.
hidden_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
CLIPEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPConfig
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
self_attn
=
CLIPAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
)
self
.
mlp
=
CLIPMLP
(
config
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
attention_mask
:
torch
.
Tensor
,
causal_attention_mask
:
torch
.
Tensor
,
output_attentions
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
FloatTensor
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual
=
hidden_states
hidden_states
=
self
.
layer_norm1
(
hidden_states
)
hidden_states
,
attn_weights
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
causal_attention_mask
=
causal_attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
layer_norm2
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
outputs
=
(
hidden_states
,)
if
output_attentions
:
outputs
+=
(
attn_weights
,)
return
outputs
class
CLIPPreTrainedModel
(
PreTrainedModel
):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class
=
CLIPConfig
base_model_prefix
=
"clip"
supports_gradient_checkpointing
=
True
_keys_to_ignore_on_load_missing
=
[
r
"position_ids"
]
def
_init_weights
(
self
,
module
):
"""Initialize the weights"""
factor
=
self
.
config
.
initializer_factor
if
isinstance
(
module
,
CLIPTextEmbeddings
):
module
.
token_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
module
.
position_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
if
hasattr
(
module
,
"padding_embedding"
):
module
.
padding_embedding
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
0.02
)
elif
isinstance
(
module
,
CLIPVisionEmbeddings
):
factor
=
self
.
config
.
initializer_factor
nn
.
init
.
normal_
(
module
.
class_embedding
,
mean
=
0.0
,
std
=
module
.
embed_dim
**-
0.5
*
factor
)
nn
.
init
.
normal_
(
module
.
patch_embedding
.
weight
,
std
=
module
.
config
.
initializer_range
*
factor
)
nn
.
init
.
normal_
(
module
.
position_embedding
.
weight
,
std
=
module
.
config
.
initializer_range
*
factor
)
elif
isinstance
(
module
,
CLIPAttention
):
factor
=
self
.
config
.
initializer_factor
in_proj_std
=
(
module
.
embed_dim
**-
0.5
)
*
((
2
*
module
.
config
.
num_hidden_layers
)
**
-
0.5
)
*
factor
out_proj_std
=
(
module
.
embed_dim
**-
0.5
)
*
factor
nn
.
init
.
normal_
(
module
.
qkv_proj
.
weight
,
std
=
in_proj_std
)
nn
.
init
.
normal_
(
module
.
out_proj
.
weight
,
std
=
out_proj_std
)
elif
isinstance
(
module
,
CLIPMLP
):
factor
=
self
.
config
.
initializer_factor
in_proj_std
=
(
(
module
.
config
.
hidden_size
**-
0.5
)
*
((
2
*
module
.
config
.
num_hidden_layers
)
**
-
0.5
)
*
factor
)
fc_std
=
(
2
*
module
.
config
.
hidden_size
)
**
-
0.5
*
factor
nn
.
init
.
normal_
(
module
.
fc1
.
weight
,
std
=
fc_std
)
nn
.
init
.
normal_
(
module
.
fc2
.
weight
,
std
=
in_proj_std
)
elif
isinstance
(
module
,
CLIPModel
):
nn
.
init
.
normal_
(
module
.
text_projection
.
weight
,
std
=
module
.
text_embed_dim
**-
0.5
*
self
.
config
.
initializer_factor
,
)
nn
.
init
.
normal_
(
module
.
visual_projection
.
weight
,
std
=
module
.
vision_embed_dim
**-
0.5
*
self
.
config
.
initializer_factor
,
)
if
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
if
isinstance
(
module
,
nn
.
Linear
)
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
def
_set_gradient_checkpointing
(
self
,
module
,
value
=
False
):
if
isinstance
(
module
,
CLIPEncoder
):
module
.
gradient_checkpointing
=
value
CLIP_START_DOCSTRING
=
r
"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
CLIP_TEXT_INPUTS_DOCSTRING
=
r
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_VISION_INPUTS_DOCSTRING
=
r
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
CLIP_INPUTS_DOCSTRING
=
r
"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class
CLIPEncoder
(
nn
.
Module
):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`CLIPEncoderLayer`].
Args:
config: CLIPConfig
"""
def
__init__
(
self
,
config
:
CLIPConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
CLIPEncoderLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)])
self
.
gradient_checkpointing
=
False
def
forward
(
self
,
inputs_embeds
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
causal_attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutput
]:
r
"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Causal mask for the text model. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
encoder_states
=
()
if
output_hidden_states
else
None
all_attentions
=
()
if
output_attentions
else
None
hidden_states
=
inputs_embeds
for
idx
,
encoder_layer
in
enumerate
(
self
.
layers
):
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,)
if
self
.
gradient_checkpointing
and
self
.
training
:
def
create_custom_forward
(
module
):
def
custom_forward
(
*
inputs
):
return
module
(
*
inputs
,
output_attentions
)
return
custom_forward
layer_outputs
=
torch
.
utils
.
checkpoint
.
checkpoint
(
create_custom_forward
(
encoder_layer
),
hidden_states
,
attention_mask
,
causal_attention_mask
,
)
else
:
layer_outputs
=
encoder_layer
(
hidden_states
,
attention_mask
,
causal_attention_mask
,
output_attentions
=
output_attentions
,
)
hidden_states
=
layer_outputs
[
0
]
if
output_attentions
:
all_attentions
=
all_attentions
+
(
layer_outputs
[
1
],)
if
output_hidden_states
:
encoder_states
=
encoder_states
+
(
hidden_states
,)
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
encoder_states
,
all_attentions
]
if
v
is
not
None
)
return
BaseModelOutput
(
last_hidden_state
=
hidden_states
,
hidden_states
=
encoder_states
,
attentions
=
all_attentions
)
class
CLIPTextTransformer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
()
self
.
config
=
config
embed_dim
=
config
.
hidden_size
self
.
embeddings
=
CLIPTextEmbeddings
(
config
)
self
.
encoder
=
CLIPEncoder
(
config
)
self
.
final_layer_norm
=
nn
.
LayerNorm
(
embed_dim
)
@
add_start_docstrings_to_model_forward
(
CLIP_TEXT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutputWithPooling
,
config_class
=
CLIPTextConfig
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPooling
]:
r
"""
Returns:
"""
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
if
input_ids
is
None
:
raise
ValueError
(
"You have to specify either input_ids"
)
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
hidden_states
=
self
.
embeddings
(
input_ids
=
input_ids
,
position_ids
=
position_ids
,
attention_mask
=
attention_mask
)
bsz
,
seq_len
=
input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask
=
self
.
_build_causal_attention_mask
(
bsz
,
seq_len
).
to
(
hidden_states
.
device
)
# expand attention_mask
if
attention_mask
is
not
None
:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask
=
_expand_mask
(
attention_mask
,
hidden_states
.
dtype
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
attention_mask
=
None
,
causal_attention_mask
=
None
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
last_hidden_state
=
encoder_outputs
[
0
]
last_hidden_state
=
self
.
final_layer_norm
(
last_hidden_state
)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output
=
last_hidden_state
[
torch
.
arange
(
last_hidden_state
.
shape
[
0
]),
input_ids
.
argmax
(
dim
=-
1
)]
if
not
return_dict
:
return
(
last_hidden_state
,
pooled_output
)
+
encoder_outputs
[
1
:]
return
BaseModelOutputWithPooling
(
last_hidden_state
=
last_hidden_state
,
pooler_output
=
pooled_output
,
hidden_states
=
encoder_outputs
.
hidden_states
,
attentions
=
encoder_outputs
.
attentions
,
)
def
_build_causal_attention_mask
(
self
,
bsz
,
seq_len
):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask
=
torch
.
empty
(
bsz
,
seq_len
,
seq_len
)
mask
.
fill_
(
torch
.
tensor
(
float
(
"-inf"
)))
mask
.
triu_
(
1
)
# zero out the lower diagonal
mask
=
mask
.
unsqueeze
(
1
)
# expand mask
return
mask
class
CLIPTextModel
(
CLIPPreTrainedModel
):
config_class
=
CLIPTextConfig
def
__init__
(
self
,
config
:
CLIPTextConfig
):
super
().
__init__
(
config
)
self
.
text_model
=
CLIPTextTransformer
(
config
)
# Initialize weights and apply final processing
self
.
post_init
()
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
text_model
.
embeddings
.
token_embedding
def
set_input_embeddings
(
self
,
value
):
self
.
text_model
.
embeddings
.
token_embedding
=
value
@
add_start_docstrings_to_model_forward
(
CLIP_TEXT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
BaseModelOutputWithPooling
,
config_class
=
CLIPTextConfig
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
position_ids
:
Optional
[
torch
.
Tensor
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
Tuple
,
BaseModelOutputWithPooling
]:
r
"""
Returns:
Examples:
```python
>>> from transformers import CLIPTokenizer, CLIPTextModel
>>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return
self
.
text_model
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
position_ids
=
position_ids
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
)
src/diffusers/models/unet_glide.py
View file @
9fdbc14e
...
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -435,7 +435,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample
=-
1
,
num_heads_upsample
=-
1
,
use_scale_shift_norm
=
False
,
use_scale_shift_norm
=
False
,
resblock_updown
=
False
,
resblock_updown
=
False
,
encoder_channels
=
None
,
transformer_dim
=
512
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
register
(
self
.
register
(
...
@@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -455,7 +455,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
num_heads_upsample
=
num_heads_upsample
,
num_heads_upsample
=
num_heads_upsample
,
use_scale_shift_norm
=
use_scale_shift_norm
,
use_scale_shift_norm
=
use_scale_shift_norm
,
resblock_updown
=
resblock_updown
,
resblock_updown
=
resblock_updown
,
encoder_channels
=
encoder_channels
,
transformer_dim
=
transformer_dim
,
)
)
if
num_heads_upsample
==
-
1
:
if
num_heads_upsample
==
-
1
:
...
@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
self
.
channel_mult
=
channel_mult
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
self
.
conv_resample
=
conv_resample
self
.
use_checkpoint
=
use_checkpoint
self
.
use_checkpoint
=
use_checkpoint
self
.
dtype
=
torch
.
float16
if
use_fp16
else
torch
.
float32
#
self.dtype = torch.float16 if use_fp16 else torch.float32
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
self
.
num_heads_upsample
=
num_heads_upsample
...
@@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -482,6 +482,8 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
linear
(
time_embed_dim
,
time_embed_dim
),
linear
(
time_embed_dim
,
time_embed_dim
),
)
)
self
.
transformer_proj
=
nn
.
Linear
(
transformer_dim
,
self
.
model_channels
*
4
)
ch
=
input_ch
=
int
(
channel_mult
[
0
]
*
model_channels
)
ch
=
input_ch
=
int
(
channel_mult
[
0
]
*
model_channels
)
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
ch
,
3
,
padding
=
1
))])
self
.
input_blocks
=
nn
.
ModuleList
([
TimestepEmbedSequential
(
conv_nd
(
dims
,
in_channels
,
ch
,
3
,
padding
=
1
))])
self
.
_feature_size
=
ch
self
.
_feature_size
=
ch
...
@@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -508,7 +510,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
use_checkpoint
=
use_checkpoint
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
encoder_channels
=
transformer_dim
,
)
)
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
...
@@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -551,7 +553,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
use_checkpoint
=
use_checkpoint
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads
,
num_heads
=
num_heads
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
encoder_channels
=
transformer_dim
,
),
),
ResBlock
(
ResBlock
(
ch
,
ch
,
...
@@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -587,7 +589,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
use_checkpoint
=
use_checkpoint
,
use_checkpoint
=
use_checkpoint
,
num_heads
=
num_heads_upsample
,
num_heads
=
num_heads_upsample
,
num_head_channels
=
num_head_channels
,
num_head_channels
=
num_head_channels
,
encoder_channels
=
encoder_channels
,
encoder_channels
=
transformer_dim
,
)
)
)
)
if
level
and
i
==
num_res_blocks
:
if
level
and
i
==
num_res_blocks
:
...
@@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -642,10 +644,6 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
:param y: an [N] Tensor of labels, if class-conditional.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
:return: an [N x C x ...] Tensor of outputs.
"""
"""
assert
(
y
is
not
None
)
==
(
self
.
num_classes
is
not
None
),
"must specify y if and only if the model is class-conditional"
hs
=
[]
hs
=
[]
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
emb
=
self
.
time_embed
(
timestep_embedding
(
timesteps
,
self
.
model_channels
))
...
@@ -653,13 +651,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
...
@@ -653,13 +651,15 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
transformer_proj
=
self
.
transformer_proj
(
transformer_out
[:,
-
1
])
transformer_out
=
transformer_out
.
permute
(
0
,
2
,
1
)
# NLC -> NCL
transformer_out
=
transformer_out
.
permute
(
0
,
2
,
1
)
# NLC -> NCL
h
=
x
.
type
(
self
.
dtype
)
emb
=
emb
+
transformer_proj
.
to
(
emb
)
h
=
x
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
)
h
=
module
(
h
,
emb
,
transformer_out
)
hs
.
append
(
h
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
)
h
=
self
.
middle_block
(
h
,
emb
,
transformer_out
)
for
module
in
self
.
output_blocks
:
for
module
in
self
.
output_blocks
:
h
=
torch
.
cat
([
h
,
hs
.
pop
()],
dim
=
1
)
other
=
hs
.
pop
(
)
h
=
module
(
h
,
emb
)
h
=
torch
.
cat
([
h
,
other
],
dim
=
1
)
h
=
h
.
type
(
x
.
dtype
)
h
=
module
(
h
,
emb
,
transformer_out
)
return
self
.
out
(
h
)
return
self
.
out
(
h
)
src/diffusers/models/unet_ldm.py
View file @
9fdbc14e
...
@@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -830,7 +830,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
self
.
conv_resample
=
conv_resample
self
.
conv_resample
=
conv_resample
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
use_checkpoint
=
use_checkpoint
self
.
use_checkpoint
=
use_checkpoint
self
.
dtype
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
dtype
_
=
torch
.
float16
if
use_fp16
else
torch
.
float32
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
num_head_channels
=
num_head_channels
self
.
num_head_channels
=
num_head_channels
self
.
num_heads_upsample
=
num_heads_upsample
self
.
num_heads_upsample
=
num_heads_upsample
...
@@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
...
@@ -1060,7 +1060,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
assert
y
.
shape
==
(
x
.
shape
[
0
],)
assert
y
.
shape
==
(
x
.
shape
[
0
],)
emb
=
emb
+
self
.
label_emb
(
y
)
emb
=
emb
+
self
.
label_emb
(
y
)
h
=
x
.
type
(
self
.
dtype
)
h
=
x
.
type
(
self
.
dtype
_
)
for
module
in
self
.
input_blocks
:
for
module
in
self
.
input_blocks
:
h
=
module
(
h
,
emb
,
context
)
h
=
module
(
h
,
emb
,
context
)
hs
.
append
(
h
)
hs
.
append
(
h
)
...
...
src/diffusers/pipeline_utils.py
View file @
9fdbc14e
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
importlib
import
importlib
import
os
import
os
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
# CHANGE to diffusers.utils
# CHANGE to diffusers.utils
...
@@ -35,10 +36,12 @@ logger = logging.get_logger(__name__)
...
@@ -35,10 +36,12 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES
=
{
LOADABLE_CLASSES
=
{
"diffusers"
:
{
"diffusers"
:
{
"ModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"ModelMixin"
:
[
"save_pretrained"
,
"from_pretrained"
],
"CLIPTextModel"
:
[
"save_pretrained"
,
"from_pretrained"
],
# TODO (Anton): move to transformers
"GaussianDDPMScheduler"
:
[
"save_config"
,
"from_config"
],
"GaussianDDPMScheduler"
:
[
"save_config"
,
"from_config"
],
"ClassifierFreeGuidanceScheduler"
:
[
"save_config"
,
"from_config"
],
},
},
"transformers"
:
{
"transformers"
:
{
"
ModelMixin
"
:
[
"save_pretrained"
,
"from_pretrained"
],
"
GPT2Tokenizer
"
:
[
"save_pretrained"
,
"from_pretrained"
],
},
},
}
}
...
@@ -62,7 +65,7 @@ class DiffusionPipeline(ConfigMixin):
...
@@ -62,7 +65,7 @@ class DiffusionPipeline(ConfigMixin):
# set models
# set models
setattr
(
self
,
name
,
module
)
setattr
(
self
,
name
,
module
)
register_dict
=
{
"_module"
:
self
.
__module__
.
split
(
"."
)[
-
1
]
+
".py"
}
register_dict
=
{
"_module"
:
self
.
__module__
.
split
(
"."
)[
-
1
]
+
".py"
}
self
.
register
(
**
register_dict
)
self
.
register
(
**
register_dict
)
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
...
...
src/diffusers/schedulers/__init__.py
View file @
9fdbc14e
...
@@ -16,4 +16,5 @@
...
@@ -16,4 +16,5 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
.classifier_free_guidance
import
ClassifierFreeGuidanceScheduler
from
.gaussian_ddpm
import
GaussianDDPMScheduler
from
.gaussian_ddpm
import
GaussianDDPMScheduler
src/diffusers/schedulers/classifier_free_guidance.py
0 → 100644
View file @
9fdbc14e
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
import
torch
from
torch
import
nn
from
..configuration_utils
import
ConfigMixin
SAMPLING_CONFIG_NAME
=
"scheduler_config.json"
def
linear_beta_schedule
(
timesteps
,
beta_start
,
beta_end
):
return
torch
.
linspace
(
beta_start
,
beta_end
,
timesteps
,
dtype
=
torch
.
float64
)
def
betas_for_alpha_bar
(
num_diffusion_timesteps
,
alpha_bar
,
max_beta
=
0.999
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas
=
[]
for
i
in
range
(
num_diffusion_timesteps
):
t1
=
i
/
num_diffusion_timesteps
t2
=
(
i
+
1
)
/
num_diffusion_timesteps
betas
.
append
(
min
(
1
-
alpha_bar
(
t2
)
/
alpha_bar
(
t1
),
max_beta
))
return
np
.
array
(
betas
,
dtype
=
np
.
float64
)
class
ClassifierFreeGuidanceScheduler
(
nn
.
Module
,
ConfigMixin
):
config_name
=
SAMPLING_CONFIG_NAME
def
__init__
(
self
,
timesteps
=
1000
,
beta_schedule
=
"squaredcos_cap_v2"
,
):
super
().
__init__
()
self
.
register
(
timesteps
=
timesteps
,
beta_schedule
=
beta_schedule
,
)
self
.
num_timesteps
=
int
(
timesteps
)
if
beta_schedule
==
"squaredcos_cap_v2"
:
# GLIDE cosine schedule
self
.
betas
=
betas_for_alpha_bar
(
timesteps
,
lambda
t
:
math
.
cos
((
t
+
0.008
)
/
1.008
*
math
.
pi
/
2
)
**
2
,
)
else
:
raise
NotImplementedError
(
f
"
{
beta_schedule
}
does is not implemented for
{
self
.
__class__
}
"
)
alphas
=
1.0
-
self
.
betas
self
.
alphas_cumprod
=
np
.
cumprod
(
alphas
,
axis
=
0
)
self
.
alphas_cumprod_prev
=
np
.
append
(
1.0
,
self
.
alphas_cumprod
[:
-
1
])
# calculations for diffusion q(x_t | x_{t-1}) and others
self
.
sqrt_recip_alphas_cumprod
=
np
.
sqrt
(
1.0
/
self
.
alphas_cumprod
)
self
.
sqrt_recipm1_alphas_cumprod
=
np
.
sqrt
(
1.0
/
self
.
alphas_cumprod
-
1
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self
.
posterior_variance
=
self
.
betas
*
(
1.0
-
self
.
alphas_cumprod_prev
)
/
(
1.0
-
self
.
alphas_cumprod
)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self
.
posterior_log_variance_clipped
=
np
.
log
(
np
.
append
(
self
.
posterior_variance
[
1
],
self
.
posterior_variance
[
1
:])
)
self
.
posterior_mean_coef1
=
self
.
betas
*
np
.
sqrt
(
self
.
alphas_cumprod_prev
)
/
(
1.0
-
self
.
alphas_cumprod
)
self
.
posterior_mean_coef2
=
(
1.0
-
self
.
alphas_cumprod_prev
)
*
np
.
sqrt
(
alphas
)
/
(
1.0
-
self
.
alphas_cumprod
)
def
sample_noise
(
self
,
shape
,
device
,
generator
=
None
):
# always sample on CPU to be deterministic
return
torch
.
randn
(
shape
,
generator
=
generator
).
to
(
device
)
def
__len__
(
self
):
return
self
.
num_timesteps
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment