Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
485b8bb0
Unverified
Commit
485b8bb0
authored
Sep 09, 2024
by
YiYi Xu
Committed by
GitHub
Sep 09, 2024
Browse files
refactor `get_timesteps` for SDXL img2img + add set_begin_index (#9375)
* refator + add begin_index * add kolors img2img to doc
parent
d08ad658
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
86 additions
and
60 deletions
+86
-60
docs/source/en/api/pipelines/kolors.md
docs/source/en/api/pipelines/kolors.md
+8
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
...pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+13
-10
src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+13
-10
src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+13
-10
src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+13
-10
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
...able_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+13
-10
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
...able_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+13
-10
No files found.
docs/source/en/api/pipelines/kolors.md
View file @
485b8bb0
...
@@ -105,3 +105,11 @@ image.save("kolors_ipa_sample.png")
...
@@ -105,3 +105,11 @@ image.save("kolors_ipa_sample.png")
-
all
-
all
-
__call__
-
__call__
## KolorsImg2ImgPipeline
[[autodoc]] KolorsImg2ImgPipeline
-
all
-
__call__
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
View file @
485b8bb0
...
@@ -1024,14 +1024,16 @@ class StableDiffusionXLControlNetInpaintPipeline(
...
@@ -1024,14 +1024,16 @@ class StableDiffusionXLControlNetInpaintPipeline(
if
denoising_start
is
None
:
if
denoising_start
is
None
:
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
else
:
t_start
=
0
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
# Strength is irrelevant if we directly request a timestep to start at;
else
:
# that is, strength is determined by the denoising_start instead.
# Strength is irrelevant if we directly request a timestep to start at;
if
denoising_start
i
s
not
None
:
# that is, strength is determined by the
denoising_start i
nstead.
discrete_timestep_cutoff
=
int
(
discrete_timestep_cutoff
=
int
(
round
(
round
(
self
.
scheduler
.
config
.
num_train_timesteps
self
.
scheduler
.
config
.
num_train_timesteps
...
@@ -1039,7 +1041,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
...
@@ -1039,7 +1041,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
)
)
)
)
num_inference_steps
=
(
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
num_inference_steps
=
(
self
.
scheduler
.
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
# if the scheduler is a 2nd order scheduler we might have to do +1
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# because `num_inference_steps` might be even given that every timestep
...
@@ -1050,11 +1052,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
...
@@ -1050,11 +1052,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_inference_steps
=
num_inference_steps
+
1
num_inference_steps
=
num_inference_steps
+
1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps
=
timesteps
[
-
num_inference_steps
:]
t_start
=
len
(
self
.
scheduler
.
timesteps
)
-
num_inference_steps
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
)
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
-
t_start
def
_get_add_time_ids
(
def
_get_add_time_ids
(
self
,
self
,
original_size
,
original_size
,
...
...
src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
View file @
485b8bb0
...
@@ -564,14 +564,16 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
...
@@ -564,14 +564,16 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
if
denoising_start
is
None
:
if
denoising_start
is
None
:
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
else
:
t_start
=
0
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
# Strength is irrelevant if we directly request a timestep to start at;
else
:
# that is, strength is determined by the denoising_start instead.
# Strength is irrelevant if we directly request a timestep to start at;
if
denoising_start
i
s
not
None
:
# that is, strength is determined by the
denoising_start i
nstead.
discrete_timestep_cutoff
=
int
(
discrete_timestep_cutoff
=
int
(
round
(
round
(
self
.
scheduler
.
config
.
num_train_timesteps
self
.
scheduler
.
config
.
num_train_timesteps
...
@@ -579,7 +581,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
...
@@ -579,7 +581,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
)
)
)
)
num_inference_steps
=
(
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
num_inference_steps
=
(
self
.
scheduler
.
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
# if the scheduler is a 2nd order scheduler we might have to do +1
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# because `num_inference_steps` might be even given that every timestep
...
@@ -590,11 +592,12 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
...
@@ -590,11 +592,12 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
num_inference_steps
=
num_inference_steps
+
1
num_inference_steps
=
num_inference_steps
+
1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps
=
timesteps
[
-
num_inference_steps
:]
t_start
=
len
(
self
.
scheduler
.
timesteps
)
-
num_inference_steps
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
)
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
-
t_start
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
def
prepare_latents
(
def
prepare_latents
(
self
,
image
,
timestep
,
batch_size
,
num_images_per_prompt
,
dtype
,
device
,
generator
=
None
,
add_noise
=
True
self
,
image
,
timestep
,
batch_size
,
num_images_per_prompt
,
dtype
,
device
,
generator
=
None
,
add_noise
=
True
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
View file @
485b8bb0
...
@@ -648,14 +648,16 @@ class StableDiffusionXLPAGImg2ImgPipeline(
...
@@ -648,14 +648,16 @@ class StableDiffusionXLPAGImg2ImgPipeline(
if
denoising_start
is
None
:
if
denoising_start
is
None
:
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
else
:
t_start
=
0
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
# Strength is irrelevant if we directly request a timestep to start at;
else
:
# that is, strength is determined by the denoising_start instead.
# Strength is irrelevant if we directly request a timestep to start at;
if
denoising_start
i
s
not
None
:
# that is, strength is determined by the
denoising_start i
nstead.
discrete_timestep_cutoff
=
int
(
discrete_timestep_cutoff
=
int
(
round
(
round
(
self
.
scheduler
.
config
.
num_train_timesteps
self
.
scheduler
.
config
.
num_train_timesteps
...
@@ -663,7 +665,7 @@ class StableDiffusionXLPAGImg2ImgPipeline(
...
@@ -663,7 +665,7 @@ class StableDiffusionXLPAGImg2ImgPipeline(
)
)
)
)
num_inference_steps
=
(
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
num_inference_steps
=
(
self
.
scheduler
.
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
# if the scheduler is a 2nd order scheduler we might have to do +1
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# because `num_inference_steps` might be even given that every timestep
...
@@ -674,11 +676,12 @@ class StableDiffusionXLPAGImg2ImgPipeline(
...
@@ -674,11 +676,12 @@ class StableDiffusionXLPAGImg2ImgPipeline(
num_inference_steps
=
num_inference_steps
+
1
num_inference_steps
=
num_inference_steps
+
1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps
=
timesteps
[
-
num_inference_steps
:]
t_start
=
len
(
self
.
scheduler
.
timesteps
)
-
num_inference_steps
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
)
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
-
t_start
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
def
prepare_latents
(
def
prepare_latents
(
self
,
image
,
timestep
,
batch_size
,
num_images_per_prompt
,
dtype
,
device
,
generator
=
None
,
add_noise
=
True
self
,
image
,
timestep
,
batch_size
,
num_images_per_prompt
,
dtype
,
device
,
generator
=
None
,
add_noise
=
True
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
View file @
485b8bb0
...
@@ -897,14 +897,16 @@ class StableDiffusionXLPAGInpaintPipeline(
...
@@ -897,14 +897,16 @@ class StableDiffusionXLPAGInpaintPipeline(
if
denoising_start
is
None
:
if
denoising_start
is
None
:
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
else
:
t_start
=
0
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
# Strength is irrelevant if we directly request a timestep to start at;
else
:
# that is, strength is determined by the denoising_start instead.
# Strength is irrelevant if we directly request a timestep to start at;
if
denoising_start
i
s
not
None
:
# that is, strength is determined by the
denoising_start i
nstead.
discrete_timestep_cutoff
=
int
(
discrete_timestep_cutoff
=
int
(
round
(
round
(
self
.
scheduler
.
config
.
num_train_timesteps
self
.
scheduler
.
config
.
num_train_timesteps
...
@@ -912,7 +914,7 @@ class StableDiffusionXLPAGInpaintPipeline(
...
@@ -912,7 +914,7 @@ class StableDiffusionXLPAGInpaintPipeline(
)
)
)
)
num_inference_steps
=
(
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
num_inference_steps
=
(
self
.
scheduler
.
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
# if the scheduler is a 2nd order scheduler we might have to do +1
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# because `num_inference_steps` might be even given that every timestep
...
@@ -923,11 +925,12 @@ class StableDiffusionXLPAGInpaintPipeline(
...
@@ -923,11 +925,12 @@ class StableDiffusionXLPAGInpaintPipeline(
num_inference_steps
=
num_inference_steps
+
1
num_inference_steps
=
num_inference_steps
+
1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps
=
timesteps
[
-
num_inference_steps
:]
t_start
=
len
(
self
.
scheduler
.
timesteps
)
-
num_inference_steps
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
)
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
-
t_start
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
def
_get_add_time_ids
(
def
_get_add_time_ids
(
self
,
self
,
...
...
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
View file @
485b8bb0
...
@@ -640,14 +640,16 @@ class StableDiffusionXLImg2ImgPipeline(
...
@@ -640,14 +640,16 @@ class StableDiffusionXLImg2ImgPipeline(
if
denoising_start
is
None
:
if
denoising_start
is
None
:
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
else
:
t_start
=
0
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
# Strength is irrelevant if we directly request a timestep to start at;
else
:
# that is, strength is determined by the denoising_start instead.
# Strength is irrelevant if we directly request a timestep to start at;
if
denoising_start
i
s
not
None
:
# that is, strength is determined by the
denoising_start i
nstead.
discrete_timestep_cutoff
=
int
(
discrete_timestep_cutoff
=
int
(
round
(
round
(
self
.
scheduler
.
config
.
num_train_timesteps
self
.
scheduler
.
config
.
num_train_timesteps
...
@@ -655,7 +657,7 @@ class StableDiffusionXLImg2ImgPipeline(
...
@@ -655,7 +657,7 @@ class StableDiffusionXLImg2ImgPipeline(
)
)
)
)
num_inference_steps
=
(
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
num_inference_steps
=
(
self
.
scheduler
.
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
# if the scheduler is a 2nd order scheduler we might have to do +1
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# because `num_inference_steps` might be even given that every timestep
...
@@ -666,11 +668,12 @@ class StableDiffusionXLImg2ImgPipeline(
...
@@ -666,11 +668,12 @@ class StableDiffusionXLImg2ImgPipeline(
num_inference_steps
=
num_inference_steps
+
1
num_inference_steps
=
num_inference_steps
+
1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps
=
timesteps
[
-
num_inference_steps
:]
t_start
=
len
(
self
.
scheduler
.
timesteps
)
-
num_inference_steps
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
)
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
-
t_start
def
prepare_latents
(
def
prepare_latents
(
self
,
image
,
timestep
,
batch_size
,
num_images_per_prompt
,
dtype
,
device
,
generator
=
None
,
add_noise
=
True
self
,
image
,
timestep
,
batch_size
,
num_images_per_prompt
,
dtype
,
device
,
generator
=
None
,
add_noise
=
True
):
):
...
...
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
View file @
485b8bb0
...
@@ -901,14 +901,16 @@ class StableDiffusionXLInpaintPipeline(
...
@@ -901,14 +901,16 @@ class StableDiffusionXLInpaintPipeline(
if
denoising_start
is
None
:
if
denoising_start
is
None
:
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
init_timestep
=
min
(
int
(
num_inference_steps
*
strength
),
num_inference_steps
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
else
:
t_start
=
0
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
# Strength is irrelevant if we directly request a timestep to start at;
else
:
# that is, strength is determined by the denoising_start instead.
# Strength is irrelevant if we directly request a timestep to start at;
if
denoising_start
i
s
not
None
:
# that is, strength is determined by the
denoising_start i
nstead.
discrete_timestep_cutoff
=
int
(
discrete_timestep_cutoff
=
int
(
round
(
round
(
self
.
scheduler
.
config
.
num_train_timesteps
self
.
scheduler
.
config
.
num_train_timesteps
...
@@ -916,7 +918,7 @@ class StableDiffusionXLInpaintPipeline(
...
@@ -916,7 +918,7 @@ class StableDiffusionXLInpaintPipeline(
)
)
)
)
num_inference_steps
=
(
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
num_inference_steps
=
(
self
.
scheduler
.
timesteps
<
discrete_timestep_cutoff
).
sum
().
item
()
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
if
self
.
scheduler
.
order
==
2
and
num_inference_steps
%
2
==
0
:
# if the scheduler is a 2nd order scheduler we might have to do +1
# if the scheduler is a 2nd order scheduler we might have to do +1
# because `num_inference_steps` might be even given that every timestep
# because `num_inference_steps` might be even given that every timestep
...
@@ -927,11 +929,12 @@ class StableDiffusionXLInpaintPipeline(
...
@@ -927,11 +929,12 @@ class StableDiffusionXLInpaintPipeline(
num_inference_steps
=
num_inference_steps
+
1
num_inference_steps
=
num_inference_steps
+
1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps
=
timesteps
[
-
num_inference_steps
:]
t_start
=
len
(
self
.
scheduler
.
timesteps
)
-
num_inference_steps
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
)
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
return
timesteps
,
num_inference_steps
-
t_start
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
def
_get_add_time_ids
(
def
_get_add_time_ids
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment