renzhc / diffusers_dcu · Commits

Commit 9e31c6a7, authored Jun 21, 2022 by anton-l
parent 072d7519

    refactor GLIDE text2im pipeline, remove classifier_free_guidance
Showing 8 changed files with 54 additions and 208 deletions (+54 -208):

    examples/train_unconditional.py                        +1    -1
    src/diffusers/__init__.py                              +0    -1
    src/diffusers/optimization.py                          +1    -1
    src/diffusers/pipeline_utils.py                        +0    -1
    src/diffusers/pipelines/pipeline_glide.py              +42  -103
    src/diffusers/schedulers/__init__.py                   +0    -1
    src/diffusers/schedulers/classifier_free_guidance.py   +0   -96
    src/diffusers/schedulers/scheduling_ddpm.py            +10   -4
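Taken together, the changes move the hard-coded sampling knobs into the pipeline call. A hypothetical usage sketch of the refactored pipeline (the checkpoint id is a placeholder, not part of this commit):

    import torch
    from diffusers import GLIDE  # exported behind the is_transformers_available() check

    # placeholder checkpoint id, for illustration only
    pipeline = GLIDE.from_pretrained("fusing/glide-base")

    generator = torch.manual_seed(0)
    # guidance_scale, eta and upsample_temp are the new __call__ arguments
    # introduced by this commit; previously they were hard-coded locals.
    image = pipeline(
        "an oil painting of a corgi",
        generator=generator,
        num_inference_steps_upscale=50,
        guidance_scale=3.0,
        eta=0.0,
        upsample_temp=0.997,
    )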
examples/train_unconditional.py

@@ -10,6 +10,7 @@ from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.modeling_utils import unwrap_model
+from diffusers.optimization import get_scheduler
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
@@ -21,7 +22,6 @@ from torchvision.transforms import (
     ToTensor,
 )
 from tqdm.auto import tqdm
-from diffusers.optimization import get_scheduler

 logger = logging.get_logger(__name__)
src/diffusers/__init__.py

@@ -13,7 +13,6 @@ from .models.unet_rl import TemporalUNet
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import BDDM, DDIM, DDPM, PNDM
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
-from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler

 if is_transformers_available():
src/diffusers/optimization.py (+1 -1; hunk collapsed in the captured view)
src/diffusers/pipeline_utils.py

@@ -36,7 +36,6 @@ LOADABLE_CLASSES = {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
         "SchedulerMixin": ["save_config", "from_config"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
-        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
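Note: LOADABLE_CLASSES maps each known base class to its [save method, load method] pair, and DiffusionPipeline walks this table to persist or restore each component. The deleted entry was only needed because ClassifierFreeGuidanceScheduler subclassed ConfigMixin directly rather than SchedulerMixin; its DDPMScheduler replacement is already covered by the SchedulerMixin row. A simplified sketch of the lookup (the helper name is hypothetical, not from the library):

    LOADABLE_CLASSES = {
        "diffusers": {
            "ModelMixin": ["save_pretrained", "from_pretrained"],
            "SchedulerMixin": ["save_config", "from_config"],
            "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        },
    }

    def resolve_save_method(component) -> str:
        # Find the first registered base class in the component's MRO and
        # return the configured save method name, e.g. "save_config" for
        # any SchedulerMixin subclass such as DDPMScheduler.
        bases = {c.__name__ for c in type(component).__mro__}
        for library_classes in LOADABLE_CLASSES.values():
            for base_name, (save_method, _load_method) in library_classes.items():
                if base_name in bases:
                    return save_method
        raise ValueError(f"{type(component).__name__} is not registered as loadable")

    # intended use: getattr(scheduler, resolve_save_method(scheduler))(save_path)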
src/diffusers/pipelines/pipeline_glide.py

@@ -32,7 +32,7 @@ from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward
 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
-from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
+from ..schedulers import DDIMScheduler, DDPMScheduler
 from ..utils import logging
@@ -715,7 +715,7 @@ class GLIDE(DiffusionPipeline):
     def __init__(
         self,
         text_unet: GLIDETextToImageUNetModel,
-        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_noise_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
@@ -731,100 +731,28 @@ class GLIDE(DiffusionPipeline):
             upscale_noise_scheduler=upscale_noise_scheduler,
         )

-    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
-        """
-        Compute the mean and variance of the diffusion posterior:
-            q(x_{t-1} | x_t, x_0)
-        """
-        assert x_start.shape == x_t.shape
-        posterior_mean = (
-            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
-        assert (
-            posterior_mean.shape[0]
-            == posterior_variance.shape[0]
-            == posterior_log_variance_clipped.shape[0]
-            == x_start.shape[0]
-        )
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
-        """
-        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0.
-
-        :param model: the model, which takes a signal and a batch of timesteps as input.
-        :param x: the [N x C x ...] tensor at time t.
-        :param t: a 1-D Tensor of timesteps.
-        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
-        :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used
-                             for conditioning.
-        :return: a dict with the following keys:
-                 - 'mean': the model mean output.
-                 - 'variance': the model variance output.
-                 - 'log_variance': the log of 'variance'.
-                 - 'pred_xstart': the prediction for x_0.
-        """
-        B, C = x.shape[:2]
-        assert t.shape == (B,)
-        if transformer_out is None:
-            # super-res model
-            model_output = model(x, t, low_res)
-        else:
-            # text2image model
-            model_output = model(x, t, transformer_out)
-
-        assert model_output.shape == (B, C * 2, *x.shape[2:])
-        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
-        # The model_var_values is [-1, 1] for [min_var, max_var].
-        frac = (model_var_values + 1) / 2
-        model_log_variance = frac * max_log + (1 - frac) * min_log
-        model_variance = torch.exp(model_log_variance)
-
-        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
-        if clip_denoised:
-            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
-
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        return model_mean, model_variance, model_log_variance, pred_xstart
-
-    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
-        assert x_t.shape == eps.shape
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
-        )
-
-    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
-        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
     @torch.no_grad()
-    def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
+    def __call__(
+        self,
+        prompt,
+        generator=None,
+        torch_device=None,
+        num_inference_steps_upscale=50,
+        guidance_scale=3.0,
+        eta=0.0,
+        upsample_temp=0.997,
+    ):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
         self.upscale_unet.to(torch_device)

         # Create a classifier-free guidance sampling function
-        guidance_scale = 3.0
-
-        def text_model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
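With the dedicated scheduler gone, classifier-free guidance now lives entirely in text_model_fn: the batch holds a conditional and an unconditional copy of the same latents, and the two noise predictions are blended by extrapolating away from the unconditional one. The core update as a self-contained sketch (tensor shapes are illustrative):

    import torch

    def guided_eps(cond_eps: torch.Tensor, uncond_eps: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # guidance_scale = 0 -> unconditional prediction,
        # 1 -> conditional prediction, > 1 -> amplified conditioning.
        return uncond_eps + guidance_scale * (cond_eps - uncond_eps)

    cond_eps, uncond_eps = torch.randn(2, 1, 3, 64, 64)  # unpacks along dim 0
    eps = guided_eps(cond_eps, uncond_eps, guidance_scale=3.0)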
@@ -833,7 +761,15 @@ class GLIDE(DiffusionPipeline):
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)
+        image = torch.randn(
+            (
+                batch_size,
+                self.text_unet.in_channels,
+                self.text_unet.resolution,
+                self.text_unet.resolution,
+            ),
+            generator=generator,
+        ).to(torch_device)

         # 2. Encode tokens
         # an empty input is needed to guide the model away from it
@@ -843,25 +779,30 @@ class GLIDE(DiffusionPipeline):
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

         # 3. Run the text2image generation step
-        num_timesteps = len(self.text_noise_scheduler)
-        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
-            t = torch.tensor([i] * image.shape[0], device=torch_device)
+        num_prediction_steps = len(self.text_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             with torch.no_grad():
-                mean, variance, log_variance, pred_xstart = self.p_mean_variance(
-                    text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
-                )
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = text_model_fn(image, time_input, transformer_out)
+                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
+
+                min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
+                max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+                # The model_var_values is [-1, 1] for [min_var, max_var].
+                frac = (model_var_values + 1) / 2
+                model_log_variance = frac * max_log + (1 - frac) * min_log
+
+                pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
                 noise = torch.randn(image.shape, generator=generator).to(torch_device)
-                nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
-                image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
+                variance = torch.exp(0.5 * model_log_variance) * noise
+
+                # set current image to prev_image: x_t -> x_t-1
+                image = pred_prev_image + variance

         # 4. Run the upscaling step
         batch_size = 1
         image = image[:1]
         low_res = ((image + 1) * 127.5).round() / 127.5 - 1

-        eta = 0.0
-
-        # Tune this parameter to control the sharpness of 256x256 images.
-        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
-        upsample_temp = 0.997
-
         # Sample gaussian noise to begin loop
         image = torch.randn(
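The rewritten loop replaces the deleted p_mean_variance helper: the UNet's extra output channels (model_var_values) pick a point, in log space, between the scheduler's smallest and largest variances for that timestep. A standalone sketch of the interpolation with toy numbers:

    import torch

    def interpolate_log_variance(model_var_values: torch.Tensor, min_log: float, max_log: float) -> torch.Tensor:
        # model_var_values lies in [-1, 1]; map it to a fraction in [0, 1]
        # and blend between the two fixed log-variance bounds.
        frac = (model_var_values + 1) / 2
        return frac * max_log + (1 - frac) * min_log

    v = torch.tensor([-1.0, 0.0, 1.0])
    print(interpolate_log_variance(v, min_log=-11.5, max_log=-6.9))
    # tensor([-11.5000,  -9.2000,  -6.9000])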
@@ -877,8 +818,6 @@ class GLIDE(DiffusionPipeline):
         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)

-        # adapt the beta schedule to the number of steps
-        # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
-
         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
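The upscaler keeps its trained beta schedule and simply strides through the trained timesteps; the commented-out rescale_betas call above was the abandoned alternative. A worked example of the stride arithmetic, assuming 1000 trained timesteps:

    num_trained_timesteps = 1000
    num_inference_steps_upscale = 50
    inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
    print(list(inference_step_times)[:5], len(inference_step_times))
    # [0, 20, 40, 60, 80] 50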
src/diffusers/schedulers/__init__.py

@@ -16,7 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
 from .scheduling_grad_tts import GradTTSScheduler
src/diffusers/schedulers/classifier_free_guidance.py
deleted 100644 → 0

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch
from torch import nn

from ..configuration_utils import ConfigMixin


SAMPLING_CONFIG_NAME = "scheduler_config.json"


def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta)
                      up to that part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float64)


class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):

    config_name = SAMPLING_CONFIG_NAME

    def __init__(
        self,
        timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
    ):
        super().__init__()
        self.register_to_config(
            timesteps=timesteps,
            beta_schedule=beta_schedule,
        )

        if beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
            self.betas = betas_for_alpha_bar(
                timesteps,
                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
            )
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)

    def sample_noise(self, shape, device, generator=None):
        # always sample on CPU to be deterministic
        return torch.randn(shape, generator=generator).to(device)

    def __len__(self):
        return self.config.timesteps
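Nothing in the deleted scheduler is guidance-specific: it only hard-coded the GLIDE cosine ("squaredcos_cap_v2") beta schedule, which DDPMScheduler also supports. A sketch of constructing an equivalent replacement (keyword names follow this era's DDPMScheduler signature and should be treated as an assumption):

    from diffusers import DDPMScheduler

    # Same GLIDE cosine schedule the deleted class built via betas_for_alpha_bar;
    # "fixed_small_log"/"fixed_large_log" are the variance bounds the pipeline queries.
    text_noise_scheduler = DDPMScheduler(
        timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
        variance_type="fixed_small_log",
    )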
src/diffusers/schedulers/scheduling_ddpm.py

@@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.set_format(tensor_format=tensor_format)

-    def get_variance(self, t):
+    def get_variance(self, t, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
@@ -96,14 +96,20 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

+        if variance_type is None:
+            variance_type = self.config.variance_type
+
         # hacks - were probs added for training stability
-        if self.config.variance_type == "fixed_small":
+        if variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
         # for rl-diffuser https://arxiv.org/abs/2205.09991
-        elif self.config.variance_type == "fixed_small_log":
+        elif variance_type == "fixed_small_log":
             variance = self.log(self.clip(variance, min_value=1e-20))
-        elif self.config.variance_type == "fixed_large":
+        elif variance_type == "fixed_large":
             variance = self.betas[t]
+        elif variance_type == "fixed_large_log":
+            # GLIDE max_log
+            variance = self.log(self.betas[t])

         return variance
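The new variance_type override lets one scheduler instance report several variance parameterizations at the same timestep, which is exactly how the GLIDE loop obtains its interpolation bounds. A brief usage sketch (constructor kwargs as above, assumed):

    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(timesteps=1000, beta_schedule="squaredcos_cap_v2")

    t = 500
    default_var = scheduler.get_variance(t)                 # falls back to config.variance_type
    min_log = scheduler.get_variance(t, "fixed_small_log")  # GLIDE lower bound
    max_log = scheduler.get_variance(t, "fixed_large_log")  # GLIDE upper bound (log beta_t)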