Unverified commit e98fabc5, authored by 39th president of the United States, probably; committed by GitHub


Allow specifying denoising_start and denoising_end as integers representing the discrete timesteps, fixing the XL ensemble not working for many schedulers (#4115)

* Fix the XL ensemble not working for any Karras scheduler sigmas and having an off-by-one bug

* Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

* make style

---------

Co-authored-by: Jimmy <39@🇺🇸.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent fa356bd4
@@ -119,6 +119,7 @@ a couple community contributors which also helped shape the following `diffusers

- [SytanSD](https://github.com/SytanSD)
- [bghira](https://github.com/bghira)
- [Birch-san](https://github.com/Birch-san)
- [AmericanPresidentJimmyCarter](https://github.com/AmericanPresidentJimmyCarter)
#### 1.) Ensemble of Expert Denoisers

@@ -128,10 +129,16 @@ expert for the high-noise diffusion stage and the refiner serves as the expert f

The advantage of 1.) over 2.) is that it requires less overall denoising steps and therefore should be significantly
faster. The drawback is that one cannot really inspect the output of the base model; it will still be heavily denoised.

To use the base model and refiner as an ensemble of expert denoisers, make sure to define the span
of timesteps which should be run through the high-noise denoising stage (*i.e.* the base model) and the low-noise
denoising stage (*i.e.* the refiner model) respectively. We can set the intervals using the [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end) of the base model
and [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start) of the refiner model.
For both `denoising_end` and `denoising_start`, a float value between 0 and 1 should be passed.
When passed, the end and start of denoising are defined as proportions of the discrete timesteps
in the model's schedule.
Note that this overrides `strength` if it is also declared, since the number of denoising steps
is determined by the discrete timesteps the model was trained on and the declared fractional cutoff.
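The fraction-to-timestep mapping can be sketched as follows (a standalone illustration with a hypothetical helper name, not part of the diffusers API; `num_train_timesteps` is 1000 in SDXL's scheduler config):

```python
# Illustrative sketch: how a fractional denoising_end / denoising_start maps to
# a discrete training-timestep cutoff. Assumes num_train_timesteps=1000 (SDXL).
def fraction_to_timestep_cutoff(fraction, num_train_timesteps=1000):
    # fraction=0.8 means 80% of denoising is covered, i.e. everything down to
    # (but not including) training timestep 200 -- the high-noise portion.
    return int(round(num_train_timesteps - fraction * num_train_timesteps))

print(fraction_to_timestep_cutoff(0.8))  # 200
```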
Let's look at an example.

First, we import the two pipelines. Since the text encoders and variational autoencoder are the same

@@ -157,31 +164,49 @@ refiner = DiffusionPipeline.from_pretrained(

refiner.to("cuda")
```

Now we define the number of inference steps and the point at which the model shall be run through the
high-noise denoising stage (*i.e.* the base model).

```py
n_steps = 40
high_noise_frac = 0.8
```
Stable Diffusion XL base is trained on timesteps 0-999 and Stable Diffusion XL refiner is finetuned
from the base model on low noise timesteps 0-199 inclusive, so we use the base model for the first
800 timesteps (high noise) and the refiner for the last 200 timesteps (low noise). Hence, `high_noise_frac`
is set to 0.8, so that all steps 200-999 (the first 80% of denoising timesteps) are performed by the
base model and steps 0-199 (the last 20% of denoising timesteps) are performed by the refiner model.
Remember, the denoising process starts at **high value** (high noise) timesteps and ends at
**low value** (low noise) timesteps.

Let's run the two pipelines now. Make sure to set `denoising_end` and
`denoising_start` to the same values and keep `num_inference_steps` constant. Also remember that
the output of the base model should be in latent space:
```py
prompt = "A majestic lion jumping from a big stone at night"

image = base(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_end=high_noise_frac,
    output_type="latent",
).images
image = refiner(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_start=high_noise_frac,
    image=image,
).images[0]
```
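To see how the 40 steps divide between the two pipelines, here is a rough sketch assuming an evenly spaced schedule over 1000 training timesteps (real schedulers may offset or repeat timestep values):

```python
# Sketch of the ensemble split, assuming an evenly spaced 40-step schedule over
# 1000 training timesteps (actual scheduler spacing may differ slightly).
num_train_timesteps = 1000
n_steps = 40
high_noise_frac = 0.8

step_ratio = num_train_timesteps // n_steps  # 25
timesteps = list(range(0, num_train_timesteps, step_ratio))[::-1]  # 975 ... 0

cutoff = int(round(num_train_timesteps - high_noise_frac * num_train_timesteps))  # 200
base_steps = [t for t in timesteps if t >= cutoff]    # run by the base model
refiner_steps = [t for t in timesteps if t < cutoff]  # run by the refiner

print(len(base_steps), len(refiner_steps))  # 32 8
```

Every timestep lands in exactly one of the two lists, so the ensemble covers the full schedule without gaps or overlap.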
Let's have a look at the images

| Original Image | Ensemble of Denoisers Experts |
|---|---|
| ![lion_base_timesteps](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_base_timesteps.png) | ![lion_refined_timesteps](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_refined_timesteps.png) |
If we had just run the base model on the same 40 steps, the image would arguably have been less detailed (e.g. the lion's eyes and nose):
@@ -271,7 +296,6 @@ image = pipe(

    image=init_image,
    mask_image=mask_image,
    num_inference_steps=num_inference_steps,
    denoising_start=high_noise_frac,
    output_type="latent",
).images
......
@@ -593,11 +593,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad

                expense of slower inference.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. As a result, the returned sample will
                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
@@ -774,9 +773,15 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 7.1 Apply denoising_end
        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
......
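The value-based filter above matters for second-order schedulers. A small sketch using the 10-step Heun timesteps from this PR's tests: each timestep value appears twice, so slicing by a rounded step count cannot land cleanly on a split, while filtering by timestep value can:

```python
# Sketch: Heun (2nd-order) repeats each timestep, so the base/refiner split must
# filter by timestep value rather than slice by a rounded step count.
# Values taken from the 10-step Heun schedule in this PR's tests.
heun_timesteps = [901.0, 801.0, 801.0, 701.0, 701.0, 601.0, 601.0, 501.0, 501.0,
                  401.0, 401.0, 301.0, 301.0, 201.0, 201.0, 101.0, 101.0, 1.0, 1.0]
cutoff = 300  # split at training timestep 300

base_part = [t for t in heun_timesteps if t >= cutoff]
print(len(base_part))  # 13 -- an odd count that no even-sized slice of a
                       # 2nd-order schedule can produce
```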
@@ -497,10 +497,22 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L

            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
            t_start = max(num_inference_steps - init_timestep, 0)
        else:
            t_start = 0

        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        # Strength is irrelevant if we directly request a timestep to start at;
        # that is, strength is determined by the denoising_start instead.
        if denoising_start is not None:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_start * self.scheduler.config.num_train_timesteps)
                )
            )
            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
            return torch.tensor(timesteps), len(timesteps)

        return timesteps, num_inference_steps - t_start
    def prepare_latents(
@@ -687,26 +699,24 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L

                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that when
                `denoising_start` is declared, the value of `strength` will be ignored.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            denoising_start (`float`, *optional*):
                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
                `strength` will be ignored. The `denoising_start` parameter is particularly beneficial when this
                pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. For example, with `denoising_end` set to
                0.8 the returned sample will still retain a substantial amount of noise (the final 20% of timesteps
                are still needed) and should be denoised by a successor pipeline that has `denoising_start` set to 0.8
                so that it only denoises the final 20% of the scheduler. The denoising_end parameter should ideally be
                utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as
                elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
            guidance_scale (`float`, *optional*, defaults to 7.5):
@@ -845,11 +855,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L

        image = self.image_processor.preprocess(image)

        # 5. Prepare timesteps
        def denoising_value_valid(dnv):
            return type(dnv) == float and 0 < dnv < 1

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps,
            strength,
            device,
            denoising_start=denoising_start if denoising_value_valid(denoising_start) else None,
        )

        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
@@ -899,18 +910,26 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 9.1 Apply denoising_end
        if (
            denoising_end is not None
            and denoising_start is not None
            and denoising_value_valid(denoising_end)
            and denoising_value_valid(denoising_start)
            and denoising_start >= denoising_end
        ):
            raise ValueError(
                f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
                + f"{denoising_end} when using type float."
            )
        elif denoising_end is not None and denoising_value_valid(denoising_end):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
......
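As a sanity check on the two complementary filters (`>= cutoff` in the base pipeline, `< cutoff` in the refiner), a toy example with the 10-step DDIM timesteps used in this PR's tests:

```python
# Sketch: the base keeps timesteps >= cutoff and the refiner keeps < cutoff, so
# together they cover the schedule exactly once. DDIM 10-step values from tests.
timesteps = [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]
cutoff = 200  # denoising_end = denoising_start = 0.8

base_part = [t for t in timesteps if t >= cutoff]
refiner_part = [t for t in timesteps if t < cutoff]

assert base_part + refiner_part == timesteps  # no step skipped or repeated
```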
@@ -731,10 +731,22 @@ class StableDiffusionXLInpaintPipeline(

            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
            t_start = max(num_inference_steps - init_timestep, 0)
        else:
            t_start = 0

        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        # Strength is irrelevant if we directly request a timestep to start at;
        # that is, strength is determined by the denoising_start instead.
        if denoising_start is not None:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_start * self.scheduler.config.num_train_timesteps)
                )
            )
            timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
            return torch.tensor(timesteps), len(timesteps)

        return timesteps, num_inference_steps - t_start
    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
@@ -861,26 +873,24 @@ class StableDiffusionXLInpaintPipeline(

                `strength`. The number of denoising steps depends on the amount of noise initially added. When
                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
                portion of the reference `image`. Note that when `denoising_start` is declared, the value of
                `strength` will be ignored.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            denoising_start (`float`, *optional*):
                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
                `strength` will be ignored. The `denoising_start` parameter is particularly beneficial when this
                pipeline is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. For example, with `denoising_end` set to
                0.8 the returned sample will still retain a substantial amount of noise (the final 20% of timesteps
                are still needed) and should be denoised by a successor pipeline that has `denoising_start` set to 0.8
                so that it only denoises the final 20% of the scheduler. The denoising_end parameter should ideally be
                utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as
                elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
            guidance_scale (`float`, *optional*, defaults to 7.5):
@@ -1034,10 +1044,12 @@ class StableDiffusionXLInpaintPipeline(

        )

        # 4. set timesteps
        def denoising_value_valid(dnv):
            return type(dnv) == float and 0 < dnv < 1

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps,
            strength,
            device,
            denoising_start=denoising_start if denoising_value_valid(denoising_start) else None,
        )
        # check that number of inference steps is not < 1 - as this doesn't make sense
        if num_inference_steps < 1:
@@ -1147,18 +1159,26 @@ class StableDiffusionXLInpaintPipeline(

        # 11. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        if (
            denoising_end is not None
            and denoising_start is not None
            and denoising_value_valid(denoising_end)
            and denoising_value_valid(denoising_start)
            and denoising_start >= denoising_end
        ):
            raise ValueError(
                f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
                + f"{denoising_end} when using type float."
            )
        elif denoising_end is not None and denoising_value_valid(denoising_end):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
......
@@ -268,7 +268,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest

        pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
        pipe_2.unet.set_default_attn_processor()

        def assert_run_mixture(
            num_steps,
            split,
            scheduler_cls_orig,
            expected_tss,
            num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
        ):
            inputs = self.get_dummy_inputs(torch_device)
            inputs["num_inference_steps"] = num_steps
@@ -282,9 +288,8 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest

            pipe_1.scheduler.set_timesteps(num_steps)
            expected_steps = pipe_1.scheduler.timesteps.tolist()

            expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
            expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
            # now we monkey patch step `done_steps`
            # list into the step function for testing

@@ -297,27 +302,242 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest

            scheduler_cls.step = new_step
            inputs_1 = {
                **inputs,
                **{
                    "denoising_end": 1.0 - (split / num_train_timesteps),
                    "output_type": "latent",
                },
            }
            latents = pipe_1(**inputs_1).images[0]

            assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"

            inputs_2 = {
                **inputs,
                **{
                    "denoising_start": 1.0 - (split / num_train_timesteps),
                    "image": latents,
                },
            }
            pipe_2(**inputs_2).images[0]

            assert expected_steps_2 == done_steps[len(expected_steps_1) :]
            assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
for steps in [5, 8]: steps = 10
for split in [0.33, 0.49, 0.71]: for split in [300, 500, 700]:
for scheduler_cls in [ for scheduler_cls_timesteps in [
(DDIMScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
(EulerDiscreteScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
(DPMSolverMultistepScheduler, [901, 811, 721, 631, 541, 451, 361, 271, 181, 91]),
(UniPCMultistepScheduler, [901, 811, 721, 631, 541, 451, 361, 271, 181, 91]),
(
HeunDiscreteScheduler,
[
901.0,
801.0,
801.0,
701.0,
701.0,
601.0,
601.0,
501.0,
501.0,
401.0,
401.0,
301.0,
301.0,
201.0,
201.0,
101.0,
101.0,
1.0,
1.0,
],
),
]:
assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1])
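The consecutive duplicates in the `HeunDiscreteScheduler` list above are expected: Heun is a second-order method, so after the initial step every remaining timestep is evaluated twice. Under that doubling rule (an assumption about the schedule shape, not a diffusers API), the second-order list can be derived from its first-order counterpart:

```python
def second_order_timesteps(first_order_ts):
    """Expand a first-order timestep list into a Heun-style one.

    After the initial timestep, each remaining timestep appears twice
    (one evaluation per stage of the two-stage update), which is why the
    expected lists in the test contain consecutive duplicates.
    """
    doubled = [t for t in first_order_ts[1:] for _ in range(2)]
    return [first_order_ts[0]] + doubled


base = [901.0, 801.0, 701.0, 601.0, 501.0, 401.0, 301.0, 201.0, 101.0, 1.0]
heun = second_order_timesteps(base)
# heun -> [901.0, 801.0, 801.0, 701.0, 701.0, ..., 1.0, 1.0] (19 entries)
```

This also shows why index-based splitting fails for such schedulers: 10 inference steps produce 19 scheduler steps, so an index computed from `num_steps` points at the wrong place.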
steps = 25
for split in [300, 500, 700]:
for scheduler_cls_timesteps in [
(
DDIMScheduler,
[
961,
921,
881,
841,
801,
761,
721,
681,
641,
601,
561,
521,
481,
441,
401,
361,
321,
281,
241,
201,
161,
121,
81,
41,
1,
],
),
(
EulerDiscreteScheduler,
[
961.0,
921.0,
881.0,
841.0,
801.0,
761.0,
721.0,
681.0,
641.0,
601.0,
561.0,
521.0,
481.0,
441.0,
401.0,
361.0,
321.0,
281.0,
241.0,
201.0,
161.0,
121.0,
81.0,
41.0,
1.0,
],
),
(
DPMSolverMultistepScheduler,
[
951,
913,
875,
837,
799,
761,
723,
685,
647,
609,
571,
533,
495,
457,
419,
381,
343,
305,
267,
229,
191,
153,
115,
77,
39,
],
),
(
UniPCMultistepScheduler,
[
951,
913,
875,
837,
799,
761,
723,
685,
647,
609,
571,
533,
495,
457,
419,
381,
343,
305,
267,
229,
191,
153,
115,
77,
39,
],
),
(
HeunDiscreteScheduler,
[
961.0,
921.0,
921.0,
881.0,
881.0,
841.0,
841.0,
801.0,
801.0,
761.0,
761.0,
721.0,
721.0,
681.0,
681.0,
641.0,
641.0,
601.0,
601.0,
561.0,
561.0,
521.0,
521.0,
481.0,
481.0,
441.0,
441.0,
401.0,
401.0,
361.0,
361.0,
321.0,
321.0,
281.0,
281.0,
241.0,
241.0,
201.0,
201.0,
161.0,
161.0,
121.0,
121.0,
81.0,
81.0,
41.0,
41.0,
1.0,
1.0,
],
),
]:
assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1])
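The tests above pass `split` as a discrete timestep and convert it to the fractional value the pipelines expect via `1.0 - (split / num_train_timesteps)`. The conversion works because timestep 0 is fully denoised: stopping at a high timestep means only a small fraction of the denoising is done. A standalone sketch of that mapping (the helper name is illustrative):

```python
def timestep_to_fraction(split_ts, num_train_timesteps=1000):
    """Convert a discrete timestep into the fractional `denoising_end` /
    `denoising_start` value used by the SDXL pipelines.

    Stopping the base model at timestep 300 of a 1000-step training
    schedule means 70% of the denoising has been performed.
    """
    return 1.0 - (split_ts / num_train_timesteps)


frac = timestep_to_fraction(300)
# frac -> 0.7: pass denoising_end=0.7 to the base pipeline and
# denoising_start=0.7 to the refiner so the two stages meet exactly.
```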
def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
components = self.get_dummy_components()

@@ -328,7 +548,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
pipe_3.unet.set_default_attn_processor()

def assert_run_mixture(
num_steps,
split_1,
split_2,
scheduler_cls_orig,
num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
):
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps

@@ -343,11 +569,15 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipe_1.scheduler.set_timesteps(num_steps)
expected_steps = pipe_1.scheduler.timesteps.tolist()
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))

expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing

@@ -367,6 +597,19 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
expected_steps_1 == done_steps
), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
with self.assertRaises(ValueError) as cm:
inputs_2 = {
**inputs,
**{
"denoising_start": split_2,
"denoising_end": split_1,
"image": latents,
"output_type": "latent",
},
}
pipe_2(**inputs_2).images[0]
assert "cannot be larger than or equal to `denoising_end`" in str(cm.exception)
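The `assertRaises` block above checks that the pipeline rejects an inverted window, where a stage would start after it ends. A minimal mirror of that guard (hypothetical helper, not the diffusers API, with the error message matched to the one the test expects):

```python
def validate_denoising_window(denoising_start, denoising_end):
    # In a mixture-of-denoisers chain each stage must cover a forward
    # slice of the schedule; an inverted window would denoise nothing.
    if denoising_start >= denoising_end:
        raise ValueError(
            f"`denoising_start` ({denoising_start}) cannot be larger than or "
            f"equal to `denoising_end` ({denoising_end})"
        )


validate_denoising_window(0.19, 0.81)  # valid window: no exception
try:
    validate_denoising_window(0.81, 0.19)  # inverted window
except ValueError as err:
    message = str(err)
# message contains "cannot be larger than or equal to"
```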
inputs_2 = {
**inputs,
**{"denoising_start": split_1, "denoising_end": split_2, "image": latents, "output_type": "latent"},

@@ -383,7 +626,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
expected_steps == done_steps
), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"

for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
for scheduler_cls in [
DDIMScheduler,
...
@@ -264,7 +264,9 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
pipe_2 = StableDiffusionXLInpaintPipeline(**components).to(torch_device)
pipe_2.unet.set_default_attn_processor()

def assert_run_mixture(
num_steps, split, scheduler_cls_orig, num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps
):
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps

@@ -278,9 +280,12 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
pipe_1.scheduler.set_timesteps(num_steps)
expected_steps = pipe_1.scheduler.timesteps.tolist()
split_ts = num_train_timesteps - int(round(num_train_timesteps * split))

expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing

@@ -304,7 +309,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert expected_steps_2 == done_steps[len(expected_steps_1) :]
assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"

for steps in [5, 8, 20]:
for split in [0.33, 0.49, 0.71]:
for scheduler_cls in [
DDIMScheduler,

@@ -324,7 +329,13 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
pipe_3 = StableDiffusionXLInpaintPipeline(**components).to(torch_device)
pipe_3.unet.set_default_attn_processor()

def assert_run_mixture(
num_steps,
split_1,
split_2,
scheduler_cls_orig,
num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
):
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps

@@ -339,11 +350,15 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
pipe_1.scheduler.set_timesteps(num_steps)
expected_steps = pipe_1.scheduler.timesteps.tolist()
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))

expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing

@@ -379,7 +394,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
expected_steps == done_steps
), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"

for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
for scheduler_cls in [
DDIMScheduler,
...