renzhc / diffusers_dcu · Commit bf13b76a

Fix merge

Authored Jun 13, 2022 by anton-l
Parent: 9c530191

Showing 3 changed files with 158 additions and 112 deletions
src/diffusers/__init__.py                     +1    -5
src/diffusers/schedulers/__init__.py          +1    -6
src/diffusers/schedulers/scheduling_plms.py   +156  -101
src/diffusers/__init__.py

@@ -10,9 +10,5 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
-from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
+from .schedulers import SchedulerMixin, DDIMScheduler, DDPMScheduler
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .schedulers.ddim import DDIMScheduler
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
src/diffusers/schedulers/__init__.py

@@ -16,12 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
 from .scheduling_utils import SchedulerMixin
-from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .ddim import DDIMScheduler
-from .gaussian_ddpm import GaussianDDPMScheduler
-from .glide_ddim import GlideDDIMScheduler
-from .schedulers_utils import SchedulerMixin
src/diffusers/schedulers/scheduling_plms.py

@@ -15,6 +15,7 @@ import math
 import numpy as np
 import torch
 from tqdm import tqdm
 from ..configuration_utils import ConfigMixin
@@ -28,11 +29,11 @@ def noise_like(shape, device, repeat=False):
 def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
-    if ddim_discr_method == 'uniform':
+    if ddim_discr_method == "uniform":
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
-    elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+    elif ddim_discr_method == "quad":
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)
     else:
         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
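Note: for intuition about the two branches above, here is a minimal standalone sketch that mirrors the visible logic (illustrative step counts; not part of the commit). "uniform" strides evenly through the DDPM steps, while "quad" spaces the selected steps quadratically, packing more of them near t = 0:

import numpy as np

# Mirrors the two discretization branches of make_ddim_timesteps above.
num_ddpm_timesteps, num_ddim_timesteps = 1000, 10

c = num_ddpm_timesteps // num_ddim_timesteps
uniform = np.asarray(list(range(0, num_ddpm_timesteps, c)))
# -> [  0 100 200 300 400 500 600 700 800 900]

quad = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int)
# -> roughly [0 9 39 88 158 246 355 483 632 800]: denser at small t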
@@ -40,7 +41,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
     # add one to get the final alpha values right (the ones from first scale to data during sampling)
     steps_out = ddim_timesteps + 1
     if verbose:
-        print(f'Selected timesteps for ddim sampler: {steps_out}')
+        print(f"Selected timesteps for ddim sampler: {steps_out}")
     return steps_out
@@ -52,9 +53,11 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
     # according the the formula provided in https://arxiv.org/abs/2010.02502
     sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
     if verbose:
-        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
-        print(f'For the chosen value of eta, which is {eta}, '
-              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+        print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}")
+        print(
+            f"For the chosen value of eta, which is {eta}, "
+            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+        )
     return sigmas, alphas, alphas_prev
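Note: the sigmas line above is Eq. (16) of the DDIM paper (https://arxiv.org/abs/2010.02502): eta = 0 makes every sigma zero (deterministic DDIM), while eta = 1 recovers the DDPM-like variance. A quick numeric check, with made-up cumulative alphas for illustration only:

import numpy as np

# Numeric check of the sigma schedule formula above; alpha values are made up.
alphas = np.array([0.9, 0.7, 0.5])        # alpha_bar at the selected steps
alphas_prev = np.array([0.99, 0.9, 0.7])  # alpha_bar one selected step earlier

for eta in (0.0, 1.0):
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    print(eta, sigmas)  # eta=0.0 -> all zeros; eta=1.0 -> DDPM-like noise scale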
@@ -71,41 +74,48 @@ class PLMSSampler(object):
                 attr = attr.to(torch.device("cuda"))
             setattr(self, name, attr)

-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True):
         if ddim_eta != 0:
-            raise ValueError('ddim_eta must be 0 for PLMS')
+            raise ValueError("ddim_eta must be 0 for PLMS")
-        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
-                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            num_ddpm_timesteps=self.ddpm_num_timesteps,
+            verbose=verbose,
+        )
         alphas_cumprod = self.model.alphas_cumprod
-        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep"
         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

-        self.register_buffer('betas', to_torch(self.model.betas))
-        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+        self.register_buffer("betas", to_torch(self.model.betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev))

         # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())))
+        self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())))
+        self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())))
+        self.register_buffer("sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)))

         # ddim sampling parameters
-        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
-                                                                                   ddim_timesteps=self.ddim_timesteps,
-                                                                                   eta=ddim_eta, verbose=verbose
-        )
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+            alphacums=alphas_cumprod.cpu(),
+            ddim_timesteps=self.ddim_timesteps,
+            eta=ddim_eta,
+            verbose=verbose,
+        )
-        self.register_buffer('ddim_sigmas', ddim_sigmas)
-        self.register_buffer('ddim_alphas', ddim_alphas)
-        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
-        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        self.register_buffer("ddim_sigmas", ddim_sigmas)
+        self.register_buffer("ddim_alphas", ddim_alphas)
+        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
-                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
-        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+            (1 - self.alphas_cumprod_prev)
+            / (1 - self.alphas_cumprod)
+            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+        )
+        self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps)

     @torch.no_grad()
-    def sample(self,
-               S,
-               batch_size,
-               shape,
+    def sample(
+        self,
+        S,
+        batch_size,
+        shape,
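Note: make_schedule registers several redundant but convenient views of alphas_cumprod. One identity worth sanity-checking is the sqrt_recipm1_alphas_cumprod buffer: sqrt(1/abar - 1) equals sqrt(1 - abar) / sqrt(abar). A tiny check with made-up values (illustrative only, not part of the commit):

import numpy as np

# Identity behind the sqrt_recipm1_alphas_cumprod buffer registered above;
# the abar schedule here is made up for illustration.
abar = np.linspace(0.9999, 0.01, 50)
assert np.allclose(np.sqrt(1.0 / abar - 1), np.sqrt(1.0 - abar) / np.sqrt(abar))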
@@ -114,20 +124,20 @@ class PLMSSampler(object):
         normals_sequence=None,
         img_callback=None,
         quantize_x0=False,
-        eta=0.,
+        eta=0.0,
         mask=None,
         x0=None,
-        temperature=1.,
+        temperature=1.0,
-        noise_dropout=0.,
+        noise_dropout=0.0,
         score_corrector=None,
         corrector_kwargs=None,
         verbose=True,
         x_T=None,
         log_every_t=100,
-        unconditional_guidance_scale=1.,
+        unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
-        **kwargs
+        **kwargs,
     ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
@@ -142,13 +152,16 @@
         # sampling
         C, H, W = shape
         size = (batch_size, C, H, W)
-        print(f'Data shape for PLMS sampling is {size}')
+        print(f"Data shape for PLMS sampling is {size}")

-        samples, intermediates = self.plms_sampling(conditioning, size,
+        samples, intermediates = self.plms_sampling(
+            conditioning,
+            size,
             callback=callback,
             img_callback=img_callback,
             quantize_denoised=quantize_x0,
-            mask=mask, x0=x0,
+            mask=mask,
+            x0=x0,
             ddim_use_original_steps=False,
             noise_dropout=noise_dropout,
             temperature=temperature,
@@ -162,12 +175,26 @@ class PLMSSampler(object):
         return samples, intermediates

     @torch.no_grad()
-    def plms_sampling(self, cond, shape,
-                      x_T=None, ddim_use_original_steps=False,
-                      callback=None, timesteps=None, quantize_denoised=False,
-                      mask=None, x0=None, img_callback=None, log_every_t=100,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+    def plms_sampling(
+        self,
+        cond,
+        shape,
+        x_T=None,
+        ddim_use_original_steps=False,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        log_every_t=100,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+    ):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -181,12 +208,12 @@ class PLMSSampler(object):
             subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
             timesteps = self.ddim_timesteps[:subset_end]

-        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        intermediates = {"x_inter": [img], "pred_x0": [img]}
         time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
         print(f"Running PLMS Sampling with {total_steps} timesteps")

-        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+        iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
         old_eps = []

         for i, step in enumerate(iterator):
@@ -197,36 +224,62 @@ class PLMSSampler(object):
             if mask is not None:
                 assert x0 is not None
                 img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
-                img = img_orig * mask + (1. - mask) * img
+                img = img_orig * mask + (1.0 - mask) * img

-            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
-                                      quantize_denoised=quantize_denoised, temperature=temperature,
-                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+            outs = self.p_sample_plms(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
                 corrector_kwargs=corrector_kwargs,
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
-                old_eps=old_eps, t_next=ts_next)
+                old_eps=old_eps,
+                t_next=ts_next,
+            )
             img, pred_x0, e_t = outs
             old_eps.append(e_t)
             if len(old_eps) >= 4:
                 old_eps.pop(0)
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(pred_x0, i)
             if index % log_every_t == 0 or index == total_steps - 1:
-                intermediates['x_inter'].append(img)
-                intermediates['pred_x0'].append(pred_x0)
+                intermediates["x_inter"].append(img)
+                intermediates["pred_x0"].append(pred_x0)

         return img, intermediates

     @torch.no_grad()
-    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+    def p_sample_plms(
+        self,
+        x,
+        c,
+        t,
+        index,
+        repeat_noise=False,
+        use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        old_eps=None,
+        t_next=None,
+    ):
         b, *_, device = *x.shape, x.device

         def get_model_output(x, t):
-            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
                 e_t = self.model.apply_model(x, t, c)
             else:
                 x_in = torch.cat([x] * 2)
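Note: get_model_output, whose body is partly elided above, sets up classifier-free guidance: the batch is doubled with torch.cat([x] * 2) so the unconditional and conditional noise predictions come from a single forward pass. A standalone sketch of the usual combination follows; `model` here is a hypothetical eps-prediction callable, and the elided code in this file may differ in detail:

import torch

# Hedged sketch of classifier-free guidance as commonly written;
# `model` is a hypothetical eps-prediction callable, not this file's API.
def guided_eps(model, x, t, cond, uncond, scale):
    x_in = torch.cat([x] * 2)   # one half unconditional, one half conditional
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([uncond, cond])
    e_t_uncond, e_t = model(x_in, t_in, c_in).chunk(2)
    # push the prediction away from the unconditional direction
    return e_t_uncond + scale * (e_t - e_t_uncond)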
@@ -243,7 +296,9 @@ class PLMSSampler(object):
         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
         alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
-        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sqrt_one_minus_alphas = (
+            self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        )
         sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

         def get_x_prev_and_pred_x0(e_t, index):
@@ -251,16 +306,16 @@ class PLMSSampler(object):
             a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
             a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
             sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
             sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

             # current prediction for x_0
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
             if quantize_denoised:
                 pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
             # direction pointing to x_t
-            dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+            dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
             noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
-            if noise_dropout > 0.:
+            if noise_dropout > 0.0:
                 noise = torch.nn.functional.dropout(noise, p=noise_dropout)
             x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
             return x_prev, pred_x0
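Note: the old_eps bookkeeping visible earlier (at most four stored noise predictions, popped FIFO once len(old_eps) >= 4) exists because PLMS is a fourth-order pseudo linear multistep method. The combination step falls in regions this diff elides, so the following is a sketch of the textbook formulation with Adams-Bashforth coefficients (cf. the PNDM paper, https://arxiv.org/abs/2202.09778), not a quote of this file:

# Textbook pseudo linear multistep combination of stored noise
# predictions; illustrative, not the exact elided code of this file.
def combine_eps(e_t, old_eps):
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2                   # 2nd-order step
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12   # 3rd-order
    if len(old_eps) >= 3:
        return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24  # 4th-order
    return e_t  # no history yet; upstream bootstraps with a 2nd-order Euler-style step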