Commit 12b10cbe
authored Jun 12, 2022 by Patrick von Platen

finish refactor

parent 2d97544d
Changes: 23 files changed in this commit; this page shows 20 changed files with 236 additions and 185 deletions (+236 -185). The remaining files are on the next page of the diff view.
Makefile  (+1 -1)
src/diffusers/__init__.py  (+4 -3)
src/diffusers/configuration_utils.py  (+1 -1)
src/diffusers/modeling_utils.py  (+3 -3)
src/diffusers/pipelines/__init__.py  (+1 -1)
src/diffusers/pipelines/configuration_ldmbert.py  (+1 -1)
src/diffusers/pipelines/modeling_vae.py  (+15 -14)
src/diffusers/pipelines/old/latent_diffusion/configuration_ldmbert.py  (+1 -1)
src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py  (+29 -16)
src/diffusers/pipelines/old/latent_diffusion/modeling_ldmbert.py  (+6 -5)
src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py  (+15 -14)
src/diffusers/pipelines/pipeline_ddim.py  (+5 -3)
src/diffusers/pipelines/pipeline_ddpm.py  (+5 -3)
src/diffusers/pipelines/pipeline_glide.py  (+4 -4)
src/diffusers/pipelines/pipeline_latent_diffusion.py  (+46 -29)
src/diffusers/schedulers/__init__.py  (+2 -1)
src/diffusers/schedulers/ddim.py  (+30 -41)
src/diffusers/schedulers/gaussian_ddpm.py  (+35 -41)
src/diffusers/schedulers/schedulers_utils.py  (+30 -2)
src/diffusers/testing_utils.py  (+2 -1)
Makefile  (+1 -1)

@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src

-check_dirs := models tests src utils
+check_dirs := tests src utils

 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
src/diffusers/__init__.py  (+4 -3)

@@ -2,15 +2,16 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.

-__version__ = "0.0.1"
+__version__ = "0.0.3"

 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
+from .schedulers import SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.ddim import DDIMScheduler
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.glide_ddim import GlideDDIMScheduler
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
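For orientation, a quick smoke test of the re-ordered top-level namespace (a sketch; assumes this revision of the package is installed):

    # schedulers and pipelines are both exposed at the top level after this commit
    from diffusers import DDIM, DDPM, GLIDE, LatentDiffusion
    from diffusers import DDIMScheduler, GaussianDDPMScheduler, SchedulerMixin

    # the new SchedulerMixin (see schedulers_utils.py below) is now a scheduler base class
    assert issubclass(DDIMScheduler, SchedulerMixin)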
src/diffusers/configuration_utils.py  (+1 -1)

@@ -213,7 +213,7 @@ class ConfigMixin:
         passed_keys = set(init_dict.keys())

         if len(expected_keys - passed_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )
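Context for the warn -> warning rename: in the standard library, Logger.warn is a deprecated alias of Logger.warning; it emits a DeprecationWarning and the alias was eventually removed in Python 3.13. A minimal illustration (requires a Python version where the alias still exists):

    import logging
    import warnings

    logger = logging.getLogger("demo")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        logger.warn("old spelling")  # deprecated alias of warning()
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    logger.warning("preferred spelling")  # no deprecation warning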
src/diffusers/modeling_utils.py  (+3 -3)

@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

         if len(unexpected_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"

@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
         else:
             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
         if len(missing_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."

@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
                     for key, shape1, shape2 in mismatched_keys
                 ]
             )
-            logger.warn(
+            logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                 f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
src/diffusers/pipelines/__init__.py  (+1 -1)

 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
-from .pipeline_latent_diffusion import LatentDiffusion
 from .pipeline_glide import GLIDE
+from .pipeline_latent_diffusion import LatentDiffusion
src/diffusers/pipelines/configuration_ldmbert.py  (+1 -1)

@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
src/diffusers/pipelines/modeling_vae.py  (+15 -14)
(hunks shown without +/- markers contain only formatting changes)

@@ -2,10 +2,10 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm

 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin

@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
                 return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
-                return 0.5 * torch.sum(
-                    torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3],
+                )

     def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
         return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,

@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
         self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):

@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
\ No newline at end of file
+        return dec, posterior
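As a sanity check, the closed form in kl() above (for other=None) matches torch.distributions; a sketch with hypothetical shapes:

    import torch
    from torch.distributions import Normal, kl_divergence

    mean = torch.randn(1, 4, 8, 8)
    logvar = torch.randn(1, 4, 8, 8)
    std = (0.5 * logvar).exp()
    var = logvar.exp()

    # formula used in DiagonalGaussianDistribution.kl() when other is None
    kl_manual = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=[1, 2, 3])
    kl_reference = kl_divergence(Normal(mean, std), Normal(0.0, 1.0)).sum(dim=[1, 2, 3])
    assert torch.allclose(kl_manual, kl_reference, atol=1e-5)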
src/diffusers/pipelines/old/latent_diffusion/configuration_ldmbert.py  (+1 -1)

@@ -123,7 +123,7 @@ class LDMBertConfig(PretrainedConfig):
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
src/diffusers/pipelines/old/latent_diffusion/modeling_latent_diffusion.py  (+29 -16)
(hunks shown without +/- markers contain only formatting changes)

-import tqdm
 import torch
+import tqdm

 from diffusers import DiffusionPipeline

-from .configuration_ldmbert import LDMBertConfig  # NOQA
-from .modeling_ldmbert import LDMBertModel  # NOQA

 # add these relative imports here, so we can load from hub
 from .modeling_vae import AutoencoderKL  # NOQA
+from .configuration_ldmbert import LDMBertConfig  # NOQA
+from .modeling_ldmbert import LDMBertModel  # NOQA


 class LatentDiffusion(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):

@@ -14,7 +16,16 @@ class LatentDiffusion(DiffusionPipeline):
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

     @torch.no_grad()
-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:

@@ -23,16 +34,18 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

         # get unconditional embeddings for classifier free guidence
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]

         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]

         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

@@ -41,7 +54,7 @@ class LatentDiffusion(DiffusionPipeline):
             device=torch_device,
             generator=generator,
         )

         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding

@@ -60,7 +73,7 @@ class LatentDiffusion(DiffusionPipeline):
                 timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
             else:
                 # for classifier free guidance, we need to do two forward passes
                 # here we concanate embedding and unconditioned embedding in a single batch
                 # to avoid doing two forward passes
                 image_in = torch.cat([image] * 2)
                 context = torch.cat([uncond_embeddings, text_embedding])

@@ -68,12 +81,12 @@ class LatentDiffusion(DiffusionPipeline):
             # 1. predict noise residual
             pred_noise_t = self.unet(image_in, timesteps, context=context)

             # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)

             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)

@@ -87,8 +100,8 @@ class LatentDiffusion(DiffusionPipeline):
                 image = pred_prev_image + variance

         # scale and decode image with vae
         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)

         image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

         return image
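The guidance block above implements classifier-free guidance with the two-samples-in-one-batch trick; a toy sketch of the arithmetic (fake_unet is a hypothetical stand-in for the real UNet):

    import torch

    guidance_scale = 7.5
    image = torch.randn(1, 4, 32, 32)

    def fake_unet(x, context):  # hypothetical stand-in for self.unet
        return 0.1 * x + context.mean(dim=1)[:, None, None, None]

    context = torch.cat([torch.zeros(1, 8), torch.ones(1, 8)])  # [uncond, cond]
    image_in = torch.cat([image] * 2)  # one forward pass covers both branches
    pred = fake_unet(image_in, context)
    pred_uncond, pred_cond = pred.chunk(2)
    guided = pred_uncond + guidance_scale * (pred_cond - pred_uncond)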
src/diffusers/pipelines/old/latent_diffusion/modeling_ldmbert.py  (+6 -5)
(hunks shown without +/- markers contain only formatting changes)

@@ -43,6 +43,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
+
 from .configuration_ldmbert import LDMBertConfig

@@ -662,7 +663,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

     def forward(
         self,
         input_ids=None,

@@ -674,7 +675,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
     ):
         outputs = self.model(
             input_ids,

@@ -689,15 +690,15 @@ class LDMBertModel(LDMBertPreTrainedModel):
         sequence_output = outputs[0]

         # logits = self.to_logits(sequence_output)
         # outputs = (logits,) + outputs[1:]
         # if labels is not None:
         #     loss_fct = CrossEntropyLoss()
         #     loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
         #     outputs = (loss,) + outputs
         # if not return_dict:
         #     return outputs
         return BaseModelOutput(
             last_hidden_state=sequence_output,
             # hidden_states=outputs[1],
src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py  (+15 -14)
(same changes as src/diffusers/pipelines/modeling_vae.py above; hunks shown without +/- markers contain only formatting changes)

@@ -2,10 +2,10 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm

 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin

@@ -740,29 +740,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
                 return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
-                return 0.5 * torch.sum(
-                    torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3],
+                )

     def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
         return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,

@@ -834,7 +835,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
         self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):

@@ -855,4 +856,4 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
\ No newline at end of file
+        return dec, posterior
src/diffusers/pipelines/pipeline_ddim.py  (+5 -3)

@@ -17,12 +17,14 @@
 import torch

 import tqdm

 from ..pipeline_utils import DiffusionPipeline


 class DDIM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):

@@ -36,11 +38,11 @@ class DDIM(DiffusionPipeline):
         self.unet.to(torch_device)

         # Sample gaussian noise to begin loop
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)

         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding

@@ -63,7 +65,7 @@ class DDIM(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
                 variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise

             # 4. set current image to prev_image: x_t -> x_t-1
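Note the replacement pattern: torch.randn(..., generator=generator) draws on the generator's device (CPU by default) and only then moves the tensor to the target device, which keeps sampling reproducible regardless of where the model runs. A small sketch of the pattern:

    import torch

    generator = torch.Generator().manual_seed(0)  # CPU generator
    noise = torch.randn((1, 3, 64, 64), generator=generator)  # sampled on CPU
    noise = noise.to("cuda" if torch.cuda.is_available() else "cpu")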
src/diffusers/pipelines/pipeline_ddpm.py  (+5 -3)

@@ -17,12 +17,14 @@
 import torch

 import tqdm

 from ..pipeline_utils import DiffusionPipeline


 class DDPM(DiffusionPipeline):
     def __init__(self, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

     def __call__(self, batch_size=1, generator=None, torch_device=None):

@@ -32,11 +34,11 @@ class DDPM(DiffusionPipeline):
         self.unet.to(torch_device)

         # Sample gaussian noise to begin loop
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)

         num_prediction_steps = len(self.noise_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):

@@ -50,7 +52,7 @@ class DDPM(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if t > 0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(image.shape, generator=generator).to(image.device)
                 variance = self.noise_scheduler.get_variance(t).sqrt() * noise

             # 4. set current image to prev_image: x_t -> x_t-1
src/diffusers/pipelines/pipeline_glide.py  (+4 -4)

@@ -24,10 +24,6 @@ import torch.utils.checkpoint
 from torch import nn

 import tqdm
-from ..pipeline_utils import DiffusionPipeline
-from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
-from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
 from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

@@ -40,6 +36,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )

+from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from ..pipeline_utils import DiffusionPipeline
+from ..schedulers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler
+

 #####################
 # START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
src/diffusers/pipelines/pipeline_latent_diffusion.py  (+46 -29)
(hunks shown without +/- markers contain only formatting changes)

@@ -2,13 +2,14 @@
 import math

 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
+import tqdm

-from ..pipeline_utils import DiffusionPipeline
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
+from ..pipeline_utils import DiffusionPipeline


 def get_timestep_embedding(timesteps, embedding_dim):

@@ -740,29 +741,30 @@ class DiagonalGaussianDistribution(object):
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
                 return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
-                return 0.5 * torch.sum(
-                    torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3],
+                )

     def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
         return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

     def mode(self):
         return self.mean


 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,

@@ -834,7 +836,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
             give_pre_end=give_pre_end,
         )
         self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

     def encode(self, x):

@@ -861,10 +863,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
 class LatentDiffusion(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
         super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

     @torch.no_grad()
-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:

@@ -873,25 +885,26 @@ class LatentDiffusion(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)

         # get unconditional embeddings for classifier free guidence
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]

         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]

         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            device=torch_device,
             generator=generator,
         )
+        image = image.to(torch_device)

         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding

@@ -910,7 +923,7 @@ class LatentDiffusion(DiffusionPipeline):
                 timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
             else:
                 # for classifier free guidance, we need to do two forward passes
                 # here we concanate embedding and unconditioned embedding in a single batch
                 # to avoid doing two forward passes
                 image_in = torch.cat([image] * 2)
                 context = torch.cat([uncond_embeddings, text_embedding])

@@ -918,12 +931,12 @@ class LatentDiffusion(DiffusionPipeline):
             # 1. predict noise residual
             pred_noise_t = self.unet(image_in, timesteps, context=context)

             # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)

             # 2. get actual t and t-1
             train_step = inference_step_times[t]
             prev_train_step = inference_step_times[t - 1] if t > 0 else -1

@@ -953,7 +966,11 @@ class LatentDiffusion(DiffusionPipeline):
             # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
             # Note: eta = 1.0 essentially corresponds to DDPM
             if eta > 0.0:
-                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+                noise = torch.randn(
+                    (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+                    generator=generator,
+                )
+                noise = noise.to(torch_device)
                 prev_image = pred_prev_image + std_dev_t * noise
             else:
                 prev_image = pred_prev_image

@@ -962,8 +979,8 @@ class LatentDiffusion(DiffusionPipeline):
             image = prev_image

         # scale and decode image with vae
         image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)

         image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

         return image
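The inference schedule above subsamples the training timesteps with a plain range; e.g. with the default 1000 trained steps and 50 inference steps:

    num_trained_timesteps = 1000
    num_inference_steps = 50
    inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
    assert list(inference_step_times)[:4] == [0, 20, 40, 60]
    assert len(list(inference_step_times)) == num_inference_steps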
src/diffusers/schedulers/__init__.py  (+2 -1)

@@ -17,6 +17,7 @@
 # limitations under the License.

 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .gaussian_ddpm import GaussianDDPMScheduler
 from .ddim import DDIMScheduler
+from .gaussian_ddpm import GaussianDDPMScheduler
 from .glide_ddim import GlideDDIMScheduler
+from .schedulers_utils import SchedulerMixin
src/diffusers/schedulers/ddim.py  (+30 -41)

@@ -13,20 +13,13 @@
 # limitations under the License.

 import math

-import torch
-from torch import nn
+import numpy as np

 from ..configuration_utils import ConfigMixin
-from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
-
-
-SAMPLING_CONFIG_NAME = "scheduler_config.json"
+from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule


-class DDIMScheduler(nn.Module, ConfigMixin):
-    config_name = SAMPLING_CONFIG_NAME
-
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
         timesteps=1000,

@@ -34,6 +27,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_end=0.02,
         beta_schedule="linear",
         clip_predicted_image=True,
+        tensor_format="np",
     ):
         super().__init__()
         self.register(

@@ -46,35 +40,34 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.clip_image = clip_predicted_image

         if beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
+            self.betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

-        alphas = 1.0 - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.one = np.array(1.0)

-        self.register_buffer("betas", betas.to(torch.float32))
-        self.register_buffer("alphas", alphas.to(torch.float32))
-        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
+        self.set_format(tensor_format=tensor_format)

         # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
         # TODO(PVP) - check how much of these is actually necessary!
         # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
         # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
         # if variance_type == "fixed_small":
         #     log_variance = torch.log(variance.clamp(min=1e-20))
         # elif variance_type == "fixed_large":
         #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))

     def get_alpha(self, time_step):
         return self.alphas[time_step]

@@ -84,7 +77,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def get_alpha_prod(self, time_step):
         if time_step < 0:
-            return torch.tensor(1.0)
+            return self.one
         return self.alphas_cumprod[time_step]

     def get_orig_t(self, t, num_inference_steps):

@@ -128,28 +121,24 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # 3. compute predicted original image from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
+        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 4. Clip "predicted x_0"
         if self.clip_image:
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            pred_original_image = self.clip(pred_original_image, -1, 1)

         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         variance = self.get_variance(t, num_inference_steps)
-        std_dev_t = eta * variance.sqrt()
+        std_dev_t = eta * variance ** (0.5)

         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * residual
+        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual

         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
+        pred_prev_image = alpha_prod_t_prev ** (0.5) * pred_original_image + pred_image_direction

         return pred_prev_image

-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
-
     def __len__(self):
         return self.timesteps
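After this refactor the scheduler holds NumPy arrays by default (tensor_format="np") and pipelines opt into torch tensors via set_format("pt"); a sketch of the intended usage, assuming this revision of the package:

    import numpy as np
    import torch
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(timesteps=1000)  # tensor_format defaults to "np"
    assert isinstance(scheduler.alphas_cumprod, np.ndarray)

    scheduler = scheduler.set_format("pt")  # converts every ndarray attribute in place
    assert isinstance(scheduler.alphas_cumprod, torch.Tensor)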
src/diffusers/schedulers/gaussian_ddpm.py  (+35 -41)

@@ -13,19 +13,13 @@
 # limitations under the License.

 import math

-import torch
-from torch import nn
+import numpy as np

 from ..configuration_utils import ConfigMixin
-from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
-
-
-SAMPLING_CONFIG_NAME = "scheduler_config.json"
+from .schedulers_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule


-class GaussianDDPMScheduler(nn.Module, ConfigMixin):
-    config_name = SAMPLING_CONFIG_NAME
-
+class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
         timesteps=1000,

@@ -34,6 +28,7 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule="linear",
         variance_type="fixed_small",
         clip_predicted_image=True,
+        tensor_format="np",
     ):
         super().__init__()
         self.register(

@@ -49,35 +44,38 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         self.variance_type = variance_type

         if beta_schedule == "linear":
-            betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
-            betas = betas_for_alpha_bar(
+            self.betas = betas_for_alpha_bar(
                 timesteps,
                 lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
             )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

-        alphas = 1.0 - betas
-        alphas_cumprod = torch.cumprod(alphas, axis=0)
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.one = np.array(1.0)

-        self.register_buffer("betas", betas.to(torch.float32))
-        self.register_buffer("alphas", alphas.to(torch.float32))
-        self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
+        self.set_format(tensor_format=tensor_format)
+        # self.register_buffer("betas", betas.to(torch.float32))
+        # self.register_buffer("alphas", alphas.to(torch.float32))
+        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))

         # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
         # TODO(PVP) - check how much of these is actually necessary!
         # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
         # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
         # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
         # if variance_type == "fixed_small":
         #     log_variance = torch.log(variance.clamp(min=1e-20))
         # elif variance_type == "fixed_large":
         #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))

     def get_alpha(self, time_step):
         return self.alphas[time_step]

@@ -87,7 +85,7 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
     def get_alpha_prod(self, time_step):
         if time_step < 0:
-            return torch.tensor(1.0)
+            return self.one
         return self.alphas_cumprod[time_step]

     def get_variance(self, t):

@@ -97,11 +95,11 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous image
         # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image
-        variance = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t))
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)

         # hacks - were probs added for training stability
         if self.variance_type == "fixed_small":
-            variance = variance.clamp(min=1e-20)
+            variance = self.clip(variance, min_value=1e-20)
         elif self.variance_type == "fixed_large":
             variance = self.get_beta(t)

@@ -116,16 +114,16 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         # 2. compute predicted original image from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image = (image - beta_prod_t.sqrt() * residual) / alpha_prod_t.sqrt()
+        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)

         # 3. Clip "predicted x_0"
         if self.clip_predicted_image:
-            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            pred_original_image = self.clip(pred_original_image, -1, 1)

         # 4. Compute coefficients for pred_original_image x_0 and current image x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.get_beta(t)) / beta_prod_t
-        current_image_coeff = self.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
+        pred_original_image_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
+        current_image_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t

         # 5. Compute predicted previous image µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf

@@ -133,9 +131,5 @@ class GaussianDDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_image

-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
-
     def __len__(self):
         return self.timesteps
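A quick numeric check of the posterior variance computed in get_variance (formula (7) of the DDPM paper), assuming the usual beta_start=0.0001 and beta_end=0.02 defaults of the linear schedule:

    import numpy as np

    betas = np.linspace(0.0001, 0.02, 1000, dtype=np.float32)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    t = 10
    # beta_tilde_t = (1 - alphabar_{t-1}) / (1 - alphabar_t) * beta_t
    variance = (1 - alphas_cumprod[t - 1]) / (1 - alphas_cumprod[t]) * betas[t]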
src/diffusers/schedulers/schedulers_utils.py  (+30 -2)

@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 import torch

+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+

 def linear_beta_schedule(timesteps, beta_start, beta_end):
-    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
+    return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)


 def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):

@@ -35,4 +39,28 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
         t1 = i / num_diffusion_timesteps
         t2 = (i + 1) / num_diffusion_timesteps
         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return torch.tensor(betas, dtype=torch.float64)
+    return np.array(betas, dtype=np.float32)
+
+
+class SchedulerMixin:
+
+    config_name = SCHEDULER_CONFIG_NAME
+
+    def set_format(self, tensor_format="pt"):
+        self.tensor_format = tensor_format
+        if tensor_format == "pt":
+            for key, value in vars(self).items():
+                if isinstance(value, np.ndarray):
+                    setattr(self, key, torch.from_numpy(value))
+
+        return self
+
+    def clip(self, tensor, min_value=None, max_value=None):
+        tensor_format = getattr(self, "tensor_format", "pt")
+
+        if tensor_format == "np":
+            return np.clip(tensor, min_value, max_value)
+        elif tensor_format == "pt":
+            return torch.clamp(tensor, min_value, max_value)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
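The clip helper dispatches on the stored tensor_format so the same scheduler code path works for both backends; a quick illustration (assuming the module path shown above):

    import numpy as np
    import torch
    from diffusers.schedulers.schedulers_utils import SchedulerMixin

    s = SchedulerMixin()
    s.tensor_format = "np"
    print(s.clip(np.array([-2.0, 0.5, 2.0]), -1, 1))  # [-1.   0.5  1. ]

    s.tensor_format = "pt"
    print(s.clip(torch.tensor([-2.0, 0.5, 2.0]), -1, 1))  # tensor([-1.0000,  0.5000,  1.0000])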
src/diffusers/testing_utils.py  (+2 -1)

 import os
 import random
 import unittest
-import torch
 from distutils.util import strtobool

+import torch
+

 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"