Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
bff9746d
Commit
bff9746d
authored
Jun 13, 2022
by
anton-l
Browse files
GLIDE + DDIM without artifacts
parent
2f8e556b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
26 deletions
+18
-26
src/diffusers/pipelines/pipeline_glide.py
src/diffusers/pipelines/pipeline_glide.py
+4
-17
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+14
-9
No files found.
src/diffusers/pipelines/pipeline_glide.py
View file @
bff9746d
...
@@ -859,9 +859,6 @@ class GLIDE(DiffusionPipeline):
...
@@ -859,9 +859,6 @@ class GLIDE(DiffusionPipeline):
nonzero_mask
=
(
t
!=
0
).
float
().
view
(
-
1
,
*
([
1
]
*
(
len
(
image
.
shape
)
-
1
)))
# no noise when t == 0
nonzero_mask
=
(
t
!=
0
).
float
().
view
(
-
1
,
*
([
1
]
*
(
len
(
image
.
shape
)
-
1
)))
# no noise when t == 0
image
=
mean
+
nonzero_mask
*
torch
.
exp
(
0.5
*
log_variance
)
*
noise
image
=
mean
+
nonzero_mask
*
torch
.
exp
(
0.5
*
log_variance
)
*
noise
image
=
image
[:
1
].
permute
(
0
,
2
,
3
,
1
)
return
image
# 4. Run the upscaling step
# 4. Run the upscaling step
batch_size
=
1
batch_size
=
1
image
=
image
[:
1
]
image
=
image
[:
1
]
...
@@ -879,20 +876,10 @@ class GLIDE(DiffusionPipeline):
...
@@ -879,20 +876,10 @@ class GLIDE(DiffusionPipeline):
)
)
image
=
image
.
to
(
torch_device
)
*
upsample_temp
image
=
image
.
to
(
torch_device
)
*
upsample_temp
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
num_trained_timesteps
=
self
.
upscale_noise_scheduler
.
timesteps
num_trained_timesteps
=
self
.
upscale_noise_scheduler
.
timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps_upscale
)
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps_upscale
)
self
.
upscale_noise_scheduler
.
rescale_betas
(
num_inference_steps_upscale
)
# adapt the beta schedule to the number of steps
# self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps_upscale
)),
total
=
num_inference_steps_upscale
):
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps_upscale
)),
total
=
num_inference_steps_upscale
):
# 1. predict noise residual
# 1. predict noise residual
...
@@ -903,7 +890,7 @@ class GLIDE(DiffusionPipeline):
...
@@ -903,7 +890,7 @@ class GLIDE(DiffusionPipeline):
# 2. predict previous mean of image x_t-1
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
upscale_noise_scheduler
.
step
(
pred_prev_image
=
self
.
upscale_noise_scheduler
.
step
(
noise_residual
,
image
,
t
,
num_inference_steps_upscale
,
eta
noise_residual
,
image
,
t
,
num_inference_steps_upscale
,
eta
,
use_clipped_residual
=
True
)
)
# 3. optionally sample variance
# 3. optionally sample variance
...
@@ -917,6 +904,6 @@ class GLIDE(DiffusionPipeline):
...
@@ -917,6 +904,6 @@ class GLIDE(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
image
=
pred_prev_image
+
variance
image
=
image
.
permute
(
0
,
2
,
3
,
1
)
image
=
image
.
clamp
(
-
1
,
1
).
permute
(
0
,
2
,
3
,
1
)
return
image
return
image
src/diffusers/schedulers/scheduling_ddim.py
View file @
bff9746d
...
@@ -69,14 +69,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -69,14 +69,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
# self.register_buffer("log_variance", log_variance.to(torch.float32))
def
rescale_betas
(
self
,
num_timesteps
):
# def rescale_betas(self, num_timesteps):
if
self
.
beta_schedule
==
"linear"
:
# # GLIDE scaling
scale
=
self
.
timesteps
/
num_timesteps
# if self.beta_schedule == "linear":
self
.
betas
=
linear_beta_schedule
(
# scale = self.timesteps / num_timesteps
num_timesteps
,
beta_start
=
self
.
beta_start
*
scale
,
beta_end
=
self
.
beta_end
*
scale
# self.betas = linear_beta_schedule(
)
# num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
self
.
alphas
=
1.0
-
self
.
betas
# )
self
.
alphas_cumprod
=
np
.
cumprod
(
self
.
alphas
,
axis
=
0
)
# self.alphas = 1.0 - self.betas
# self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
def
get_alpha
(
self
,
time_step
):
def
get_alpha
(
self
,
time_step
):
return
self
.
alphas
[
time_step
]
return
self
.
alphas
[
time_step
]
...
@@ -107,7 +108,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -107,7 +108,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return
variance
return
variance
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
,
eta
):
def
step
(
self
,
residual
,
image
,
t
,
num_inference_steps
,
eta
,
use_clipped_residual
=
False
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Ideally, read DDIM paper in-detail understanding
...
@@ -141,6 +142,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -141,6 +142,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
variance
=
self
.
get_variance
(
t
,
num_inference_steps
)
variance
=
self
.
get_variance
(
t
,
num_inference_steps
)
std_dev_t
=
eta
*
variance
**
(
0.5
)
std_dev_t
=
eta
*
variance
**
(
0.5
)
if
use_clipped_residual
:
# the residual is always re-derived from the clipped x_0 in GLIDE
residual
=
(
image
-
alpha_prod_t
**
(
0.5
)
*
pred_original_image
)
/
beta_prod_t
**
(
0.5
)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
)
**
(
0.5
)
*
residual
pred_image_direction
=
(
1
-
alpha_prod_t_prev
-
std_dev_t
**
2
)
**
(
0.5
)
*
residual
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment