Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
a2d424eb
Unverified
Commit
a2d424eb
authored
Dec 04, 2024
by
hlky
Committed by
GitHub
Dec 04, 2024
Browse files
Add `sigmas` to pipelines using FlowMatch (#10116)
parent
25ddc794
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
44 additions
and
58 deletions
+44
-58
src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+1
-8
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
.../controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+6
-6
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
..._sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+6
-6
src/diffusers/pipelines/lumina/pipeline_lumina.py
src/diffusers/pipelines/lumina/pipeline_lumina.py
+1
-8
src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+6
-6
src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+6
-6
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
...pelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+6
-6
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
...stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+6
-6
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
...stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+6
-6
No files found.
src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
View file @
a2d424eb
...
...
@@ -387,7 +387,6 @@ class AuraFlowPipeline(DiffusionPipeline):
prompt
:
Union
[
str
,
List
[
str
]]
=
None
,
negative_prompt
:
Union
[
str
,
List
[
str
]]
=
None
,
num_inference_steps
:
int
=
50
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
List
[
float
]
=
None
,
guidance_scale
:
float
=
3.5
,
num_images_per_prompt
:
Optional
[
int
]
=
1
,
...
...
@@ -424,10 +423,6 @@ class AuraFlowPipeline(DiffusionPipeline):
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -522,9 +517,7 @@ class AuraFlowPipeline(DiffusionPipeline):
# 4. Prepare timesteps
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timesteps
,
sigmas
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigmas
)
# 5. Prepare latents.
latent_channels
=
self
.
transformer
.
config
.
in_channels
...
...
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
View file @
a2d424eb
...
...
@@ -733,7 +733,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
num_inference_steps
:
int
=
28
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
Optional
[
List
[
float
]
]
=
None
,
guidance_scale
:
float
=
7.0
,
control_guidance_start
:
Union
[
float
,
List
[
float
]]
=
0.0
,
control_guidance_end
:
Union
[
float
,
List
[
float
]]
=
1.0
,
...
...
@@ -778,10 +778,10 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timestep
s (`List[
in
t]`, *optional*):
Custom
timestep
s to use for the denoising process with schedulers which support a `
timestep
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
Must be in descending order.
sigma
s (`List[
floa
t]`, *optional*):
Custom
sigma
s to use for the denoising process with schedulers which support a `
sigma
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -998,7 +998,7 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
assert
False
# 4. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timestep
s
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigma
s
)
num_warmup_steps
=
max
(
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
,
0
)
self
.
_num_timesteps
=
len
(
timesteps
)
...
...
src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
View file @
a2d424eb
...
...
@@ -787,7 +787,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
num_inference_steps
:
int
=
28
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
Optional
[
List
[
float
]
]
=
None
,
guidance_scale
:
float
=
7.0
,
control_guidance_start
:
Union
[
float
,
List
[
float
]]
=
0.0
,
control_guidance_end
:
Union
[
float
,
List
[
float
]]
=
1.0
,
...
...
@@ -833,10 +833,10 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timestep
s (`List[
in
t]`, *optional*):
Custom
timestep
s to use for the denoising process with schedulers which support a `
timestep
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
Must be in descending order.
sigma
s (`List[
floa
t]`, *optional*):
Custom
sigma
s to use for the denoising process with schedulers which support a `
sigma
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -1033,7 +1033,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
controlnet_pooled_projections
=
controlnet_pooled_projections
or
pooled_prompt_embeds
# 4. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timestep
s
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigma
s
)
num_warmup_steps
=
max
(
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
,
0
)
self
.
_num_timesteps
=
len
(
timesteps
)
...
...
src/diffusers/pipelines/lumina/pipeline_lumina.py
View file @
a2d424eb
...
...
@@ -617,7 +617,6 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
width
:
Optional
[
int
]
=
None
,
height
:
Optional
[
int
]
=
None
,
num_inference_steps
:
int
=
30
,
timesteps
:
List
[
int
]
=
None
,
guidance_scale
:
float
=
4.0
,
negative_prompt
:
Union
[
str
,
List
[
str
]]
=
None
,
sigmas
:
List
[
float
]
=
None
,
...
...
@@ -649,10 +648,6 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
num_inference_steps (`int`, *optional*, defaults to 30):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
...
...
@@ -776,9 +771,7 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
prompt_attention_mask
=
torch
.
cat
([
prompt_attention_mask
,
negative_prompt_attention_mask
],
dim
=
0
)
# 4. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timesteps
,
sigmas
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigmas
)
# 5. Prepare latents.
latent_channels
=
self
.
transformer
.
config
.
in_channels
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
View file @
a2d424eb
...
...
@@ -693,7 +693,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
num_inference_steps
:
int
=
28
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
Optional
[
List
[
float
]
]
=
None
,
guidance_scale
:
float
=
7.0
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
negative_prompt_2
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
...
...
@@ -735,10 +735,10 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timestep
s (`List[
in
t]`, *optional*):
Custom
timestep
s to use for the denoising process with schedulers which support a `
timestep
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
Must be in descending order.
sigma
s (`List[
floa
t]`, *optional*):
Custom
sigma
s to use for the denoising process with schedulers which support a `
sigma
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -890,7 +890,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
pooled_prompt_embeds
=
torch
.
cat
([
negative_pooled_prompt_embeds
,
pooled_prompt_embeds
],
dim
=
0
)
# 4. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timestep
s
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigma
s
)
num_warmup_steps
=
max
(
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
,
0
)
self
.
_num_timesteps
=
len
(
timesteps
)
...
...
src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
View file @
a2d424eb
...
...
@@ -733,7 +733,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
image
:
PipelineImageInput
=
None
,
strength
:
float
=
0.6
,
num_inference_steps
:
int
=
50
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
Optional
[
List
[
float
]
]
=
None
,
guidance_scale
:
float
=
7.0
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
negative_prompt_2
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
...
...
@@ -783,10 +783,10 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timestep
s (`List[
in
t]`, *optional*):
Custom
timestep
s to use for the denoising process with schedulers which support a `
timestep
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
Must be in descending order.
sigma
s (`List[
floa
t]`, *optional*):
Custom
sigma
s to use for the denoising process with schedulers which support a `
sigma
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -936,7 +936,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
image
=
self
.
image_processor
.
preprocess
(
image
)
# 4. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timestep
s
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigma
s
)
timesteps
,
num_inference_steps
=
self
.
get_timesteps
(
num_inference_steps
,
strength
,
device
)
latent_timestep
=
timesteps
[:
1
].
repeat
(
batch_size
*
num_images_per_prompt
)
# 5. Prepare latent variables
...
...
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
View file @
a2d424eb
...
...
@@ -679,7 +679,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
height
:
Optional
[
int
]
=
None
,
width
:
Optional
[
int
]
=
None
,
num_inference_steps
:
int
=
28
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
Optional
[
List
[
float
]
]
=
None
,
guidance_scale
:
float
=
7.0
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
negative_prompt_2
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
...
...
@@ -723,10 +723,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timestep
s (`List[
in
t]`, *optional*):
Custom
timestep
s to use for the denoising process with schedulers which support a `
timestep
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
Must be in descending order.
sigma
s (`List[
floa
t]`, *optional*):
Custom
sigma
s to use for the denoising process with schedulers which support a `
sigma
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -883,7 +883,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
pooled_prompt_embeds
=
torch
.
cat
([
negative_pooled_prompt_embeds
,
pooled_prompt_embeds
],
dim
=
0
)
# 4. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timestep
s
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigma
s
)
num_warmup_steps
=
max
(
len
(
timesteps
)
-
num_inference_steps
*
self
.
scheduler
.
order
,
0
)
self
.
_num_timesteps
=
len
(
timesteps
)
...
...
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
View file @
a2d424eb
...
...
@@ -713,7 +713,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
image
:
PipelineImageInput
=
None
,
strength
:
float
=
0.6
,
num_inference_steps
:
int
=
50
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
Optional
[
List
[
float
]
]
=
None
,
guidance_scale
:
float
=
7.0
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
negative_prompt_2
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
...
...
@@ -753,10 +753,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timestep
s (`List[
in
t]`, *optional*):
Custom
timestep
s to use for the denoising process with schedulers which support a `
timestep
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
Must be in descending order.
sigma
s (`List[
floa
t]`, *optional*):
Custom
sigma
s to use for the denoising process with schedulers which support a `
sigma
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -893,7 +893,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
image
=
self
.
image_processor
.
preprocess
(
image
)
# 4. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timestep
s
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigma
s
)
timesteps
,
num_inference_steps
=
self
.
get_timesteps
(
num_inference_steps
,
strength
,
device
)
latent_timestep
=
timesteps
[:
1
].
repeat
(
batch_size
*
num_images_per_prompt
)
...
...
src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
View file @
a2d424eb
...
...
@@ -806,7 +806,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
padding_mask_crop
:
Optional
[
int
]
=
None
,
strength
:
float
=
0.6
,
num_inference_steps
:
int
=
50
,
timesteps
:
List
[
int
]
=
None
,
sigmas
:
Optional
[
List
[
float
]
]
=
None
,
guidance_scale
:
float
=
7.0
,
negative_prompt
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
negative_prompt_2
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
...
...
@@ -874,10 +874,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timestep
s (`List[
in
t]`, *optional*):
Custom
timestep
s to use for the denoising process with schedulers which support a `
timestep
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
Must be in descending order.
sigma
s (`List[
floa
t]`, *optional*):
Custom
sigma
s to use for the denoising process with schedulers which support a `
sigma
s` argument
in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
...
...
@@ -1007,7 +1007,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
pooled_prompt_embeds
=
torch
.
cat
([
negative_pooled_prompt_embeds
,
pooled_prompt_embeds
],
dim
=
0
)
# 3. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
timestep
s
)
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
self
.
scheduler
,
num_inference_steps
,
device
,
sigmas
=
sigma
s
)
timesteps
,
num_inference_steps
=
self
.
get_timesteps
(
num_inference_steps
,
strength
,
device
)
# check that number of inference steps is not < 1 - as this doesn't make sense
if
num_inference_steps
<
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment