renzhc / diffusers_dcu · Commits · 07ffe73f
"docs/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fee93c81eb7c5e9fe1618f858f1e369567170edc"
Commit 07ffe73f authored Jun 08, 2022 by anton-l

Style

parent bb98a5b7
Showing 11 changed files with 91 additions and 96 deletions.
models/vision/glide/convert_weights.py                 +8   -7
models/vision/glide/modeling_glide.py                  +16  -14
models/vision/glide/run_glide.py                       +2   -0
src/diffusers/__init__.py                              +2   -2
src/diffusers/configuration_utils.py                   +3   -7
src/diffusers/models/__init__.py                       +1   -1
src/diffusers/models/clip_text_transformer.py          +48  -50
src/diffusers/models/unet_glide.py                     +1   -1
src/diffusers/pipeline_utils.py                        +2   -1
src/diffusers/schedulers/__init__.py                   +1   -1
src/diffusers/schedulers/classifier_free_guidance.py   +7   -12
models/vision/glide/convert_weights.py

 import torch
 from torch import nn
-from transformers import CLIPTextConfig, GPT2Tokenizer
-from diffusers import UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
+from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, UNetGLIDEModel
+from modeling_glide import GLIDE
+from transformers import CLIPTextConfig, GPT2Tokenizer

 # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt
 state_dict = torch.load("base.pt", map_location="cpu")
@@ -22,7 +23,7 @@ config = CLIPTextConfig(
 )

 model = CLIPTextModel(config).eval()
 tokenizer = GPT2Tokenizer("./glide-base/vocab.json", "./glide-base/merges.txt", pad_token="<|endoftext|>")
-#tokenizer.save_pretrained("./glide-base")
+# tokenizer.save_pretrained("./glide-base")

 hf_encoder = model.text_model
@@ -51,11 +52,11 @@ for layer_idx in range(config.num_hidden_layers):
     hf_layer.mlp.fc2.weight = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.weight"]
     hf_layer.mlp.fc2.bias = state_dict[f"transformer.resblocks.{layer_idx}.mlp.c_proj.bias"]

-#inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
-#with torch.no_grad():
+# inputs = tokenizer(["an oil painting of a corgi", ""], padding="max_length", max_length=128, return_tensors="pt")
+# with torch.no_grad():
 #    outputs = model(**inputs)
-#model.save_pretrained("./glide-base")
+# model.save_pretrained("./glide-base")

 ### Convert the UNet
@@ -80,4 +81,4 @@ scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="squar

 glide = GLIDE(unet=unet_model, noise_scheduler=scheduler, text_encoder=model, tokenizer=tokenizer)
-glide.save_pretrained("./glide-base")
\ No newline at end of file
+glide.save_pretrained("./glide-base")
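The hunks above only re-order imports and normalize comment spacing; the conversion logic is untouched. For orientation, the script's underlying pattern is renaming OpenAI GLIDE checkpoint keys (`transformer.resblocks.N.*`) onto Hugging Face CLIP layer attributes. A minimal, self-contained sketch of that pattern with toy shapes and a stand-in layer (not the real script, which assigns the tensors directly):

import torch
from torch import nn

# toy stand-in for the loaded GLIDE checkpoint (torch.load("base.pt", ...))
state_dict = {f"transformer.resblocks.{i}.mlp.c_proj.bias": torch.zeros(4) for i in range(2)}

class ToyLayer(nn.Module):
    """Stand-in for one HF CLIP encoder layer with an mlp.fc2 projection."""
    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(4, 4)

layers = [ToyLayer() for _ in range(2)]
for i, layer in enumerate(layers):
    # wrap in nn.Parameter so the module tracks the copied tensor
    layer.fc2.bias = nn.Parameter(state_dict[f"transformer.resblocks.{i}.mlp.c_proj.bias"])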
models/vision/glide/modeling_glide.py

@@ -14,12 +14,12 @@
 # limitations under the License.

-from diffusers import DiffusionPipeline, UNetGLIDEModel, ClassifierFreeGuidanceScheduler, CLIPTextModel
-from transformers import GPT2Tokenizer
+import numpy as np
+import torch
 import tqdm
-import torch
-import numpy as np
+from diffusers import ClassifierFreeGuidanceScheduler, CLIPTextModel, DiffusionPipeline, UNetGLIDEModel
+from transformers import GPT2Tokenizer


 def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -40,14 +40,16 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
 class GLIDE(DiffusionPipeline):
     def __init__(
-        self, unet: UNetGLIDEModel, noise_scheduler: ClassifierFreeGuidanceScheduler, text_encoder: CLIPTextModel, tokenizer: GPT2Tokenizer
+        self,
+        unet: UNetGLIDEModel,
+        noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_encoder: CLIPTextModel,
+        tokenizer: GPT2Tokenizer,
     ):
         super().__init__()
-        self.register_modules(
-            unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer
-        )
+        self.register_modules(unet=unet, noise_scheduler=noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer)

     def q_posterior_mean_variance(self, x_start, x_t, t):
         """
@@ -129,7 +131,9 @@ class GLIDE(DiffusionPipeline):
         self.text_encoder.to(torch_device)

         # 1. Sample gaussian noise
-        image = self.noise_scheduler.sample_noise((1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator)
+        image = self.noise_scheduler.sample_noise(
+            (1, self.unet.in_channels, 64, 64), device=torch_device, generator=generator
+        )

         # 2. Encode tokens
         # an empty input is needed to guide the model away from (
@@ -141,9 +145,7 @@ class GLIDE(DiffusionPipeline):
             t = torch.tensor([i] * image.shape[0], device=torch_device)
             mean, variance, log_variance, pred_xstart = self.p_mean_variance(self.unet, transformer_out, image, t)
             noise = self.noise_scheduler.sample_noise(image.shape)
-            nonzero_mask = (
-                (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
-            )  # no noise when t == 0
+            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

         return image
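The last hunk only collapses the `nonzero_mask` expression onto one line, but the semantics are worth spelling out: at the final step (t == 0) no noise is added, otherwise the sampled noise is scaled by exp(0.5 * log_variance). A standalone sketch with assumed toy shapes:

import torch

mean = torch.zeros(4, 3, 8, 8)          # posterior mean for a batch of 4
log_variance = torch.zeros(4, 3, 8, 8)  # posterior log-variance
noise = torch.randn(4, 3, 8, 8)
t = torch.tensor([3, 2, 1, 0])          # per-sample timesteps

# shape (4, 1, 1, 1): broadcasts over channels/pixels, zero where t == 0
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(mean.shape) - 1)))
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

assert torch.equal(image[3], mean[3])  # the t == 0 sample receives no noise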
models/vision/glide/run_glide.py

 import torch
+
 from modeling_glide import GLIDE
+

 generator = torch.Generator()
 generator = generator.manual_seed(0)
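The script seeds a dedicated `torch.Generator`, so the noise drawn during sampling is reproducible run-to-run and independent of the global RNG. A minimal illustration of that behavior (not part of the script):

import torch

a = torch.randn(2, 2, generator=torch.Generator().manual_seed(0))
b = torch.randn(2, 2, generator=torch.Generator().manual_seed(0))
assert torch.equal(a, b)  # same seed, same noise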
src/diffusers/__init__.py

@@ -5,10 +5,10 @@
 __version__ = "0.0.1"

 from .modeling_utils import ModelMixin
+from .models.clip_text_transformer import CLIPTextModel
 from .models.unet import UNetModel
 from .models.unet_glide import UNetGLIDEModel
 from .models.unet_ldm import UNetLDMModel
-from .models.clip_text_transformer import CLIPTextModel
 from .pipeline_utils import DiffusionPipeline
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/configuration_utils.py

@@ -89,7 +89,6 @@ class ConfigMixin:
         self.to_json_file(output_config_file)

         logger.info(f"ConfigMixinuration saved in {output_config_file}")
-

     @classmethod
     def get_config_dict(
@@ -183,7 +182,7 @@ class ConfigMixin:
             logger.info(f"loading configuration file {config_file}")
         else:
             logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")

         return config_dict

     @classmethod
@@ -199,9 +198,8 @@ class ConfigMixin:
                 # use value from config dict
                 init_dict[key] = config_dict.pop(key)

         unused_kwargs = config_dict.update(kwargs)
-
         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
             logger.warn(
@@ -212,9 +210,7 @@ class ConfigMixin:
     @classmethod
     def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
-        config_dict = cls.get_config_dict(
-            pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
-        )
+        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
         init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
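For orientation, `from_config` fetches a config dict and `extract_init_dict` splits it into the kwargs accepted by `__init__` versus leftovers. (Note the line `unused_kwargs = config_dict.update(kwargs)` visible above: `dict.update` returns `None`, so that assignment looks like a latent bug rather than a style issue.) A toy emulation of the split, not the library code:

import inspect

def extract_init_dict(cls, config_dict, **kwargs):
    # keys matching __init__ parameters go to init_dict; the rest are "unused"
    expected_keys = set(inspect.signature(cls.__init__).parameters) - {"self"}
    merged = {**config_dict, **kwargs}
    init_dict = {k: merged.pop(k) for k in list(merged) if k in expected_keys}
    return init_dict, merged

class ToyScheduler:
    def __init__(self, timesteps=1000, beta_schedule="linear"):
        self.timesteps = timesteps
        self.beta_schedule = beta_schedule

init, unused = extract_init_dict(ToyScheduler, {"timesteps": 50, "extra": 1})
assert init == {"timesteps": 50} and unused == {"extra": 1}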
src/diffusers/models/__init__.py

@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .clip_text_transformer import CLIPTextModel
 from .unet import UNetModel
 from .unet_glide import UNetGLIDEModel
 from .unet_ldm import UNetLDMModel
-from .clip_text_transformer import CLIPTextModel
src/diffusers/models/clip_text_transformer.py

@@ -14,14 +14,15 @@
 # limitations under the License.
 """ PyTorch CLIP model."""

-from dataclasses import dataclass
 import math
+from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn
+from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
@@ -32,7 +33,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from transformers import CLIPModel, CLIPConfig, CLIPVisionConfig, CLIPTextConfig

 logger = logging.get_logger(__name__)
@@ -153,11 +154,11 @@ class CLIPTextEmbeddings(nn.Module):
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
@@ -193,16 +194,15 @@ class CLIPAttention(nn.Module):
         )
         self.scale = 1 / math.sqrt(math.sqrt(self.head_dim))

         self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -212,9 +212,7 @@ class CLIPAttention(nn.Module):
         qkv_states = qkv_states.view(bsz, tgt_len, self.num_heads, -1)
         query_states, key_states, value_states = torch.split(qkv_states, self.head_dim, dim=-1)

-        attn_weights = torch.einsum(
-            "bthc,bshc->bhts", query_states * self.scale, key_states * self.scale
-        )
+        attn_weights = torch.einsum("bthc,bshc->bhts", query_states * self.scale, key_states * self.scale)

         wdtype = attn_weights.dtype
         attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).type(wdtype)
@@ -252,11 +250,11 @@ class CLIPEncoderLayer(nn.Module):
         self.layer_norm2 = nn.LayerNorm(self.embed_dim)

     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         causal_attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.FloatTensor]:
         """
         Args:
@@ -313,19 +311,19 @@ class CLIPPreTrainedModel(PreTrainedModel):
             module.padding_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
         elif isinstance(module, CLIPVisionEmbeddings):
             factor = self.config.initializer_factor
-            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim ** -0.5 * factor)
+            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
         elif isinstance(module, CLIPAttention):
             factor = self.config.initializer_factor
-            in_proj_std = (module.embed_dim ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
-            out_proj_std = (module.embed_dim ** -0.5) * factor
+            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
+            out_proj_std = (module.embed_dim**-0.5) * factor
             nn.init.normal_(module.qkv_proj.weight, std=in_proj_std)
             nn.init.normal_(module.out_proj.weight, std=out_proj_std)
         elif isinstance(module, CLIPMLP):
             factor = self.config.initializer_factor
-            in_proj_std = (
-                (module.config.hidden_size ** -0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
-            )
+            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
             fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
             nn.init.normal_(module.fc1.weight, std=fc_std)
@@ -333,11 +331,11 @@ class CLIPPreTrainedModel(PreTrainedModel):
         elif isinstance(module, CLIPModel):
             nn.init.normal_(
                 module.text_projection.weight,
-                std=module.text_embed_dim ** -0.5 * self.config.initializer_factor,
+                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
             )
             nn.init.normal_(
                 module.visual_projection.weight,
-                std=module.vision_embed_dim ** -0.5 * self.config.initializer_factor,
+                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
             )

         if isinstance(module, nn.LayerNorm):
@@ -463,13 +461,13 @@ class CLIPEncoder(nn.Module):
         self.gradient_checkpointing = False

     def forward(
         self,
         inputs_embeds,
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Args:
@@ -562,13 +560,13 @@ class CLIPTextTransformer(nn.Module):
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -652,13 +650,13 @@ class CLIPTextModel(CLIPPreTrainedModel):
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -684,4 +682,4 @@ class CLIPTextModel(CLIPPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-        )
\ No newline at end of file
+        )
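Beyond the formatting, the touched attention code is worth a note: a single fused `qkv_proj` produces queries, keys, and values in one matmul, and the scale `1 / sqrt(sqrt(head_dim))` is applied to both q and k, so their product carries the usual `1 / sqrt(head_dim)` factor while keeping intermediate magnitudes small (helpful in float16). A self-contained sketch with assumed toy dimensions, not the module itself:

import math

import torch
from torch import nn

bsz, tgt_len, embed_dim, num_heads = 2, 5, 8, 2
head_dim = embed_dim // num_heads
scale = 1 / math.sqrt(math.sqrt(head_dim))

qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
hidden_states = torch.randn(bsz, tgt_len, embed_dim)

# fused projection, then one split into per-head q, k, v
qkv_states = qkv_proj(hidden_states).view(bsz, tgt_len, num_heads, -1)
query_states, key_states, value_states = torch.split(qkv_states, head_dim, dim=-1)

# scale on BOTH operands: (q * s) @ (k * s) == (q @ k) / sqrt(head_dim)
attn_weights = torch.einsum("bthc,bshc->bhts", query_states * scale, key_states * scale)
attn_probs = attn_weights.softmax(dim=-1)
attn_output = torch.einsum("bhts,bshc->bthc", attn_probs, value_states).reshape(bsz, tgt_len, embed_dim)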
src/diffusers/models/unet_glide.py

@@ -470,7 +470,7 @@ class UNetGLIDEModel(ModelMixin, ConfigMixin):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        #self.dtype = torch.float16 if use_fp16 else torch.float32
+        # self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
src/diffusers/pipeline_utils.py

@@ -17,6 +17,7 @@
 import importlib
 import os
 from typing import Optional, Union
+
 from huggingface_hub import snapshot_download

 # CHANGE to diffusers.utils

@@ -64,7 +65,7 @@ class DiffusionPipeline(ConfigMixin):
             # set models
             setattr(self, name, module)

             register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}

         self.register(**register_dict)

     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
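The `register_dict` line above records which module file defines the pipeline class, so `save_pretrained`/`from_pretrained` can round-trip custom pipeline code. A plain-Python illustration of what that expression computes (hypothetical class, not `DiffusionPipeline`):

class ToyPipeline:
    pass

# e.g. "my_pipelines.glide" -> "glide.py"; "__main__" -> "__main__.py"
module_file = ToyPipeline.__module__.split(".")[-1] + ".py"
print(module_file)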
src/diffusers/schedulers/__init__.py

@@ -16,5 +16,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .gaussian_ddpm import GaussianDDPMScheduler
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
+from .gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/schedulers/classifier_free_guidance.py

@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
 import math
-from torch import nn
+
 import numpy as np
+import torch
+from torch import nn

 from ..configuration_utils import ConfigMixin
@@ -80,19 +81,13 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

         # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = (
-            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        )
+        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-        self.posterior_log_variance_clipped = np.log(
-            np.append(self.posterior_variance[1], self.posterior_variance[1:])
-        )
-        self.posterior_mean_coef1 = (
-            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        )
-        self.posterior_mean_coef2 = (
-            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
-        )
+        self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
+        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)

     def sample_noise(self, shape, device, generator=None):
         # always sample on CPU to be deterministic
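These are the standard DDPM posterior terms: for q(x_{t-1} | x_t, x_0), the mean is `coef1 * x_0 + coef2 * x_t`, and the variance is `beta_t * (1 - a_prev) / (1 - a)` where `a` is the cumulative product of the alphas. A self-contained numeric sketch with an assumed linear beta schedule (the scheduler itself uses whatever schedule it was configured with):

import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # assumed schedule, for illustration only
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# entry 0 is exactly 0, so its log is clipped by reusing entry 1
posterior_log_variance_clipped = np.log(np.append(posterior_variance[1], posterior_variance[1:]))
posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)

assert posterior_variance[0] == 0.0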