renzhc / diffusers_dcu · Commits · efa773af

Unverified commit efa773af, authored Aug 29, 2022 by Anton Lozhkov, committed via GitHub on Aug 29, 2022
Support K-LMS in img2img (#270)
* Support K-LMS in img2img
* Apply review suggestions
Parent commit: da7d4cf2
Changes: 3 changed files, 32 additions and 8 deletions

+28 -6  examples/inference/image_to_image.py
+1  -0  src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+3  -2  src/diffusers/schedulers/scheduling_lms_discrete.py
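For context, this is roughly how the example pipeline would be driven once K-LMS is supported. It is a usage sketch, not part of the commit: it assumes the `StableDiffusionImg2ImgPipeline` class defined in the example file, a Stable Diffusion v1-4 checkpoint, and the `init_image`/`["sample"]` conventions of the diffusers API of this period.

```python
# Hypothetical usage sketch; checkpoint name and parameter names are assumptions.
import PIL.Image
from diffusers import LMSDiscreteScheduler

from image_to_image import StableDiffusionImg2ImgPipeline  # the example module

# K-LMS with the beta schedule commonly used for Stable Diffusion
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler
).to("cuda")

init_image = PIL.Image.open("sketch.png").convert("RGB").resize((512, 512))
result = pipe(
    prompt="A fantasy landscape, trending on artstation",
    init_image=init_image,
    strength=0.75,
    guidance_scale=7.5,
)
result["sample"][0].save("fantasy_landscape.png")
```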
examples/inference/image_to_image.py

```diff
@@ -5,7 +5,14 @@ import numpy as np
 import torch

 import PIL
-from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DiffusionPipeline,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    UNet2DConditionModel,
+)
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
```
```diff
@@ -87,12 +94,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            timesteps = torch.tensor(
+                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+            )
+        else:
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)

         # get prompt text embeddings
         text_input = self.tokenizer(
```
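The branch reflects how the two scheduler families address their schedules: DDIM/PNDM consume an actual timestep value taken from `scheduler.timesteps`, while `LMSDiscreteScheduler` looks noise levels up by position in its sigma table, so img2img hands it the starting index instead. A small arithmetic sketch, with step count and offset assumed for illustration:

```python
# Illustrative numbers only; the step count, strength, and offset are assumptions.
num_inference_steps = 50
strength = 0.75
offset = 0

# how much of the schedule img2img keeps, as in the hunk above
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)

# K-LMS start position: an index into the 50-entry sigma schedule
lms_start = num_inference_steps - init_timestep
print(init_timestep, lms_start)  # 37 13 -> denoising runs over positions 13..49
```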
```diff
@@ -133,8 +145,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         latents = init_latents

         t_start = max(num_inference_steps - init_timestep + offset, 0)
         for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+            t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[t_index]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+                latent_model_input = latent_model_input.to(self.unet.dtype)
+                t = t.to(self.unet.dtype)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
```
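The `((sigma**2 + 1) ** 0.5)` divisor bridges a mismatch in parameterizations: K-LMS treats a noisy sample as `x + sigma * eps` (variance exploding), while the UNet was trained on variance-preserving inputs of roughly unit scale. A self-contained toy check of the scaling, with an assumed sigma value and no model involved:

```python
import torch

torch.manual_seed(0)
sample = torch.randn(4, 64, 64)           # roughly unit-variance latents
sigma = 7.0                               # an assumed noise level from the schedule
noisy = sample + sigma * torch.randn_like(sample)

print(float(noisy.std()))                 # ~ (1 + sigma**2) ** 0.5 ~ 7.07
scaled = noisy / ((sigma**2 + 1) ** 0.5)  # the scaling applied in the hunk
print(float(scaled.std()))                # ~ 1.0, back in the UNet's input range
```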
```diff
@@ -145,11 +164,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents.to(self.vae.dtype))

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
```
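The same index-versus-timestep distinction drives the `step` call: `LMSDiscreteScheduler.step` indexes its internal sigma table, so it receives the loop position, while DDIM/PNDM receive the raw timestep. The dispatch could be factored into a helper like this (a sketch; the helper name and standalone form are not part of the commit):

```python
from diffusers import LMSDiscreteScheduler

def scheduler_step(scheduler, noise_pred, t, t_index, latents, **extra_step_kwargs):
    """Hypothetical helper mirroring the branch in the loop above."""
    if isinstance(scheduler, LMSDiscreteScheduler):
        # K-LMS steps by position in its own schedule
        return scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
    # DDIM/PNDM step by the actual timestep value
    return scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
```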
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
```diff
@@ -138,6 +138,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
                 sigma = self.scheduler.sigmas[i]
+                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

             # predict the noise residual
```
src/diffusers/schedulers/scheduling_lms_discrete.py
```diff
@@ -124,8 +124,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return {"prev_sample": prev_sample}

     def add_noise(self, original_samples, noise, timesteps):
-        sigmas = self.match_shape(self.sigmas, noise)
-        noisy_samples = original_samples + noise * sigmas[timesteps]
+        sigmas = self.match_shape(self.sigmas[timesteps], noise)
+        noisy_samples = original_samples + noise * sigmas
+
         return noisy_samples

     def __len__(self):
```
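The reorder makes `add_noise` select each sample's sigma by its timestep index first and broadcast afterwards, so the broadcast shape follows the batch rather than the full schedule. A toy shape check with a stand-in `match_shape` (an assumption; the library's actual helper may differ):

```python
import torch

def match_shape(values, broadcast_array):
    # stand-in: append trailing singleton dims until `values` broadcasts
    while values.ndim < broadcast_array.ndim:
        values = values[..., None]
    return values

sigmas = torch.linspace(10.0, 0.1, 50)   # a 50-entry sigma schedule
timesteps = torch.tensor([13, 13])       # one start index per batch item
noise = torch.randn(2, 4, 64, 64)
original_samples = torch.randn(2, 4, 64, 64)

per_sample = match_shape(sigmas[timesteps], noise)  # shape (2, 1, 1, 1)
noisy_samples = original_samples + noise * per_sample
print(per_sample.shape, noisy_samples.shape)
```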