Unverified Commit b934215d authored by YiYi Xu's avatar YiYi Xu Committed by GitHub

[scheduler] support custom `timesteps` and `sigmas` (#7817)



* support custom sigmas and timesteps, dpm euler

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent 5ed3abd3
...@@ -212,6 +212,62 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
## Custom Timestep Schedules
With all our schedulers, you can choose one of the popular timestep schedules through configurations such as `timestep_spacing`, `interpolation_type`, and `use_karras_sigmas`. Some schedulers also let you pass a fully custom timestep schedule: any list of arbitrary timesteps. As an example, we will use the AYS timestep schedule, a set of 10-step optimized schedules released by researchers from NVIDIA that can achieve significantly better quality than the preset schedules. You can read more about their research [here](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/).
```python
from diffusers.schedulers import AysSchedules
sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
print(sampling_schedule)
```
```
[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]
```
You can then create a pipeline and pass this custom timestep schedule to it as `timesteps`.
```python
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler

pipe = StableDiffusionXLPipeline.from_pretrained(
"SG161222/RealVisXL_V4.0",
torch_dtype=torch.float16,
variant="fp16",
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, algorithm_type="sde-dpmsolver++")
prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
generator = torch.Generator(device="cpu").manual_seed(2487854446)
image = pipe(
prompt=prompt,
negative_prompt="",
generator=generator,
timesteps=sampling_schedule,
).images[0]
```
For the same number of steps, the image generated with the AYS schedule has better quality than one generated with the default linear timestep schedule, and it is comparable to the result of the default schedule run for 25 steps.
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">AYS timestep schedule 10 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 10 steps</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">Linearly-spaced timestep schedule 25 steps</figcaption>
</div>
</div>
> [!TIP]
> 🤗 Diffusers currently only supports `timesteps` and `sigmas` for a selected list of schedulers and pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend the feature to a scheduler or pipeline that does not currently support it!
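Custom `sigmas` work the same way as custom `timesteps`. As a rough sketch (the `sigma_min`/`sigma_max` values and the `rho=7` interpolation from Karras et al. are illustrative assumptions, not tied to a particular model), you could build your own Karras-style sigma schedule in plain Python and pass the resulting list to a supported pipeline as `sigmas=...`:

```python
def karras_sigmas(n, sigma_min=0.0292, sigma_max=14.6146, rho=7.0):
    # Interpolate linearly in sigma**(1/rho) space (Karras et al., 2022),
    # then map back; higher rho concentrates steps at low noise levels.
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = [
        (max_inv_rho + i / (n - 1) * (min_inv_rho - max_inv_rho)) ** rho
        for i in range(n)
    ]
    # Assumption: the schedule ends at sigma = 0, matching the trailing 0.0
    # in the sigma lists used elsewhere in the diffusers docs.
    return sigmas + [0.0]

sigmas = karras_sigmas(10)
# e.g. image = pipe(prompt=prompt, sigmas=sigmas).images[0]
```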
## Models
Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
...
...@@ -156,6 +156,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
...@@ -171,14 +172,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
...@@ -189,6 +194,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
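The dispatch logic in `retrieve_timesteps` can be exercised without any model weights. Here is a minimal, self-contained sketch (with a hypothetical `ToyScheduler` standing in for a real diffusers scheduler) showing how the helper inspects `set_timesteps` and routes `timesteps` vs. `sigmas`:

```python
import inspect

class ToyScheduler:
    # Hypothetical stand-in: real schedulers derive timesteps from the noise
    # schedule; here we just record something list-shaped for illustration.
    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None, sigmas=None):
        if timesteps is not None:
            self.timesteps = list(timesteps)
        elif sigmas is not None:
            # one timestep per sigma, in descending order
            self.timesteps = list(range(len(sigmas) - 1, -1, -1))
        else:
            self.timesteps = list(range(num_inference_steps - 1, -1, -1))

def retrieve_timesteps(scheduler, num_inference_steps=None, device=None, timesteps=None, sigmas=None):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    # check which keyword arguments this scheduler's set_timesteps accepts
    params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in params:
            raise ValueError(f"{scheduler.__class__} does not support custom timesteps.")
        scheduler.set_timesteps(timesteps=timesteps, device=device)
    elif sigmas is not None:
        if "sigmas" not in params:
            raise ValueError(f"{scheduler.__class__} does not support custom sigmas.")
        scheduler.set_timesteps(sigmas=sigmas, device=device)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device)
    return scheduler.timesteps, len(scheduler.timesteps)

ts, n = retrieve_timesteps(ToyScheduler(), sigmas=[14.6, 6.5, 3.0, 1.0, 0.0])
```

The same pattern repeats across the pipelines below; only the surrounding class changes.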
...@@ -865,6 +880,7 @@ class AnimateDiffSDXLPipeline(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
...@@ -923,6 +939,10 @@ class AnimateDiffSDXLPipeline(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
...@@ -1104,7 +1124,9 @@ class AnimateDiffSDXLPipeline(
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
...
...@@ -137,6 +137,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
...@@ -152,14 +153,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
...@@ -170,6 +175,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
...@@ -750,6 +765,7 @@ class AnimateDiffVideoToVideoPipeline(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.5,
strength: float = 0.8,
negative_prompt: Optional[Union[str, List[str]]] = None,
...@@ -783,6 +799,14 @@ class AnimateDiffVideoToVideoPipeline(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
strength (`float`, *optional*, defaults to 0.8):
Higher strength leads to more differences between original video and generated video.
guidance_scale (`float`, *optional*, defaults to 7.5):
...@@ -912,7 +936,9 @@ class AnimateDiffVideoToVideoPipeline(
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
...
...@@ -97,6 +97,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
...@@ -112,14 +113,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
...@@ -130,6 +135,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
...@@ -892,6 +907,7 @@ class StableDiffusionControlNetPipeline(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
...@@ -941,6 +957,10 @@ class StableDiffusionControlNetPipeline(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
...@@ -1162,7 +1182,9 @@ class StableDiffusionControlNetPipeline(
assert False
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
...
...@@ -79,6 +79,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
...@@ -94,14 +95,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
...@@ -112,6 +117,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
...@@ -673,6 +688,7 @@ class AltDiffusionPipeline(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
...@@ -848,7 +864,9 @@ class AltDiffusionPipeline(
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
...
...@@ -119,6 +119,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
...@@ -134,14 +135,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
...@@ -152,6 +157,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
...@@ -753,6 +768,7 @@ class AltDiffusionImg2ImgPipeline(
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
...@@ -919,7 +935,9 @@ class AltDiffusionImg2ImgPipeline(
image = self.image_processor.preprocess(image)
# 5. set timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
...
@@ -63,6 +63,7 @@ def retrieve_timesteps(
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
+   sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
    """
@@ -78,14 +79,18 @@ def retrieve_timesteps(
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
-           Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-           timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-           must be `None`.
+           Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+           `num_inference_steps` and `sigmas` must be `None`.
+       sigmas (`List[float]`, *optional*):
+           Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+           `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
+   if timesteps is not None and sigmas is not None:
+       raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
@@ -96,6 +101,16 @@ def retrieve_timesteps(
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
+   elif sigmas is not None:
+       accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+       if not accept_sigmas:
+           raise ValueError(
+               f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+               f" sigmas schedules. Please check whether you are using the correct scheduler."
+           )
+       scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+       timesteps = scheduler.timesteps
+       num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
...
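The dispatch pattern in `retrieve_timesteps` above (mutual-exclusion check, then `inspect.signature` feature detection before forwarding the keyword) can be exercised without diffusers. The following is a minimal sketch under assumptions: `ToyScheduler` and `retrieve_timesteps_sketch` are hypothetical stand-ins, not the real library classes, and device handling is omitted.

```python
import inspect

# Hypothetical stand-in for a real scheduler. Its set_timesteps advertises
# which customizations it supports purely through its signature.
class ToyScheduler:
    def set_timesteps(self, num_inference_steps=None, timesteps=None, sigmas=None):
        if timesteps is not None:
            self.timesteps = list(timesteps)
        elif sigmas is not None:
            # One timestep per sigma, counting down, mirroring descending schedules.
            self.timesteps = list(range(len(sigmas) - 1, -1, -1))
        else:
            self.timesteps = list(range(num_inference_steps - 1, -1, -1))

def retrieve_timesteps_sketch(scheduler, num_inference_steps=None, timesteps=None, sigmas=None):
    # Exactly one of the custom schedules may be set.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in params:
            raise ValueError(f"{scheduler.__class__} does not support custom timesteps.")
        scheduler.set_timesteps(timesteps=timesteps)
    elif sigmas is not None:
        # Feature detection: only forward `sigmas` if set_timesteps accepts it.
        if "sigmas" not in params:
            raise ValueError(f"{scheduler.__class__} does not support custom sigmas.")
        scheduler.set_timesteps(sigmas=sigmas)
    else:
        scheduler.set_timesteps(num_inference_steps=num_inference_steps)
    # num_inference_steps is re-derived from the schedule the scheduler produced.
    return scheduler.timesteps, len(scheduler.timesteps)

ts, n = retrieve_timesteps_sketch(ToyScheduler(), sigmas=[14.6, 3.2, 0.9, 0.0])
```

Note how `num_inference_steps` is recomputed as `len(scheduler.timesteps)` in the custom branches, which is why callers must not pass it alongside `timesteps` or `sigmas`.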
@@ -67,6 +67,7 @@ def retrieve_timesteps(
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
+   sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
    """
@@ -82,14 +83,18 @@ def retrieve_timesteps(
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
-           Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-           timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-           must be `None`.
+           Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+           `num_inference_steps` and `sigmas` must be `None`.
+       sigmas (`List[float]`, *optional*):
+           Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+           `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
+   if timesteps is not None and sigmas is not None:
+       raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
@@ -100,6 +105,16 @@ def retrieve_timesteps(
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
+   elif sigmas is not None:
+       accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+       if not accept_sigmas:
+           raise ValueError(
+               f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+               f" sigmas schedules. Please check whether you are using the correct scheduler."
+           )
+       scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+       timesteps = scheduler.timesteps
+       num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
...
@@ -175,6 +175,7 @@ def retrieve_timesteps(
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
+   sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
    """
@@ -190,14 +191,18 @@ def retrieve_timesteps(
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
-           Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-           timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-           must be `None`.
+           Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+           `num_inference_steps` and `sigmas` must be `None`.
+       sigmas (`List[float]`, *optional*):
+           Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+           `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
+   if timesteps is not None and sigmas is not None:
+       raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
@@ -208,6 +213,16 @@ def retrieve_timesteps(
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
+   elif sigmas is not None:
+       accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+       if not accept_sigmas:
+           raise ValueError(
+               f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+               f" sigmas schedules. Please check whether you are using the correct scheduler."
+           )
+       scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+       timesteps = scheduler.timesteps
+       num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
@@ -672,6 +687,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
        negative_prompt: str = "",
        num_inference_steps: int = 20,
        timesteps: List[int] = None,
+       sigmas: List[float] = None,
        guidance_scale: float = 4.5,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = None,
@@ -707,8 +723,13 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
-               Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
-               timesteps are used. Must be in descending order.
+               Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+               in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+               passed will be used. Must be in descending order.
+           sigmas (`List[float]`, *optional*):
+               Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+               their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+               will be used.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -837,7 +858,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # 4. Prepare timesteps
-       timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+       timesteps, num_inference_steps = retrieve_timesteps(
+           self.scheduler, num_inference_steps, device, timesteps, sigmas
+       )

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
...
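The new `sigmas` argument accepts any decreasing noise schedule as a plain list of floats. As an example of building one by hand, the snippet below computes the Karras et al. (2022) rho-schedule, the same family of schedules that `use_karras_sigmas=True` produces internally. The `sigma_min`/`sigma_max` defaults here are illustrative values, not read from any particular model config.

```python
# Karras rho-schedule: interpolate linearly in sigma^(1/rho) space between
# sigma_max and sigma_min, then raise back to the rho power. This concentrates
# steps near low noise levels, where sampling quality is most sensitive.
def karras_sigmas(n, sigma_min=0.0292, sigma_max=14.6146, rho=7.0):
    ramp = [i / (n - 1) for i in range(n)]
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # Descending order: starts at sigma_max (t=0) and ends at sigma_min (t=1).
    return [(max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho for t in ramp]

sigmas = karras_sigmas(10)
```

A list like this could then be passed as the `sigmas` argument introduced in this diff (for a scheduler whose `set_timesteps` accepts it), instead of relying on the preset spacing strategies.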
@@ -119,6 +119,7 @@ def retrieve_timesteps(
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
+   sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
    """
@@ -134,14 +135,18 @@ def retrieve_timesteps(
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
-           Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-           timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-           must be `None`.
+           Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+           `num_inference_steps` and `sigmas` must be `None`.
+       sigmas (`List[float]`, *optional*):
+           Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+           `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
+   if timesteps is not None and sigmas is not None:
+       raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
@@ -152,6 +157,16 @@ def retrieve_timesteps(
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
+   elif sigmas is not None:
+       accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+       if not accept_sigmas:
+           raise ValueError(
+               f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+               f" sigmas schedules. Please check whether you are using the correct scheduler."
+           )
+       scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+       timesteps = scheduler.timesteps
+       num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
@@ -599,6 +614,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
        negative_prompt: str = "",
        num_inference_steps: int = 20,
        timesteps: List[int] = None,
+       sigmas: List[float] = None,
        guidance_scale: float = 4.5,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = None,
@@ -634,8 +650,13 @@ class PixArtSigmaPipeline(DiffusionPipeline):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
-               Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
-               timesteps are used. Must be in descending order.
+               Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+               in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+               passed will be used. Must be in descending order.
+           sigmas (`List[float]`, *optional*):
+               Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+               their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+               will be used.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -763,7 +784,9 @@ class PixArtSigmaPipeline(DiffusionPipeline):
            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

        # 4. Prepare timesteps
-       timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+       timesteps, num_inference_steps = retrieve_timesteps(
+           self.scheduler, num_inference_steps, device, timesteps, sigmas
+       )

        # 5. Prepare latents.
        latent_channels = self.transformer.config.in_channels
...
@@ -75,6 +75,7 @@ def retrieve_timesteps(
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
+   sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
    """
@@ -90,14 +91,18 @@ def retrieve_timesteps(
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
-           Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-           timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-           must be `None`.
+           Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+           `num_inference_steps` and `sigmas` must be `None`.
+       sigmas (`List[float]`, *optional*):
+           Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+           `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
+   if timesteps is not None and sigmas is not None:
+       raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
@@ -108,6 +113,16 @@ def retrieve_timesteps(
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
+   elif sigmas is not None:
+       accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+       if not accept_sigmas:
+           raise ValueError(
+               f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+               f" sigmas schedules. Please check whether you are using the correct scheduler."
+           )
+       scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+       timesteps = scheduler.timesteps
+       num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
@@ -744,6 +759,7 @@ class StableDiffusionPipeline(
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
+       sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
@@ -780,6 +796,10 @@ class StableDiffusionPipeline(
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
+           sigmas (`List[float]`, *optional*):
+               Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+               their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+               will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -929,7 +949,9 @@ class StableDiffusionPipeline(
        )

        # 4. Prepare timesteps
-       timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+       timesteps, num_inference_steps = retrieve_timesteps(
+           self.scheduler, num_inference_steps, device, timesteps, sigmas
+       )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
...
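Custom `timesteps` and custom `sigmas` describe the same schedule in two coordinate systems: for variance-preserving schedulers the sigma at integer timestep t is sqrt((1 - abar_t) / abar_t), where abar_t is the cumulative product of alphas. The sketch below illustrates this correspondence for the AYS schedule printed earlier; the "scaled_linear" beta constants are assumptions chosen to match common Stable Diffusion configs, not values read from any model.

```python
import math

# Illustrative conversion from integer timesteps to sigmas under an assumed
# scaled_linear beta schedule (beta_start/beta_end are hypothetical defaults).
def sigmas_for_timesteps(timesteps, num_train_timesteps=1000,
                         beta_start=0.00085, beta_end=0.012):
    # scaled_linear: betas are the squares of a linear ramp in sqrt-beta space.
    betas = [
        (beta_start ** 0.5 + (beta_end ** 0.5 - beta_start ** 0.5) * i / (num_train_timesteps - 1)) ** 2
        for i in range(num_train_timesteps)
    ]
    # abar_t = prod_{s<=t} (1 - beta_s), the cumulative signal retention.
    abar, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        abar.append(prod)
    # sigma_t = sqrt((1 - abar_t) / abar_t); larger t means more noise.
    return [math.sqrt((1.0 - abar[t]) / abar[t]) for t in timesteps]

ays = [999, 845, 730, 587, 443, 310, 193, 116, 53, 13]
sigmas = sigmas_for_timesteps(ays)
```

Since sigma grows monotonically with t, a descending timestep schedule like AYS maps to a descending sigma schedule, which is the form the new `sigmas` argument expects.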
@@ -115,6 +115,7 @@ def retrieve_timesteps(
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
+   sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
    """
@@ -130,14 +131,18 @@ def retrieve_timesteps(
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
-           Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-           timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-           must be `None`.
+           Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+           `num_inference_steps` and `sigmas` must be `None`.
+       sigmas (`List[float]`, *optional*):
+           Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+           `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
+   if timesteps is not None and sigmas is not None:
+       raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
@@ -148,6 +153,16 @@ def retrieve_timesteps(
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
+   elif sigmas is not None:
+       accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+       if not accept_sigmas:
+           raise ValueError(
+               f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+               f" sigmas schedules. Please check whether you are using the correct scheduler."
+           )
+       scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+       timesteps = scheduler.timesteps
+       num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
@@ -833,6 +848,7 @@ class StableDiffusionImg2ImgPipeline(
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        timesteps: List[int] = None,
+       sigmas: List[float] = None,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
@@ -875,6 +891,10 @@ class StableDiffusionImg2ImgPipeline(
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
+           sigmas (`List[float]`, *optional*):
+               Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+               their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+               will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1009,7 +1029,9 @@ class StableDiffusionImg2ImgPipeline(
        image = self.image_processor.preprocess(image)

        # 5. set timesteps
-       timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+       timesteps, num_inference_steps = retrieve_timesteps(
+           self.scheduler, num_inference_steps, device, timesteps, sigmas
+       )
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
...
...@@ -179,6 +179,7 @@ def retrieve_timesteps( ...@@ -179,6 +179,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None, num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None, device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None, timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -194,14 +195,18 @@ def retrieve_timesteps( ...@@ -194,14 +195,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*): timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` `num_inference_steps` and `sigmas` must be `None`.
must be `None`. sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns: Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps. second element is the number of inference steps.
""" """
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
@@ -212,6 +217,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
@@ -984,6 +999,7 @@ class StableDiffusionInpaintPipeline(
strength: float = 1.0,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -1046,6 +1062,10 @@ class StableDiffusionInpaintPipeline(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -1220,7 +1240,9 @@ class StableDiffusionInpaintPipeline(
)
# 4. set timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
...
@@ -80,6 +80,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
@@ -95,6 +96,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
@@ -113,6 +118,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
@@ -719,6 +734,7 @@ class StableDiffusionLDM3DPipeline(
width: Optional[int] = None,
num_inference_steps: int = 49,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -751,6 +767,14 @@ class StableDiffusionLDM3DPipeline(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -888,7 +912,9 @@ class StableDiffusionLDM3DPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
...
@@ -80,6 +80,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
@@ -95,6 +96,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
@@ -113,6 +118,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
...
@@ -107,6 +107,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
@@ -122,6 +123,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
@@ -140,6 +145,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
@@ -820,6 +835,7 @@ class StableDiffusionXLPipeline(
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -876,6 +892,10 @@ class StableDiffusionXLPipeline(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -1075,7 +1095,9 @@ class StableDiffusionXLPipeline(
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
...
@@ -124,6 +124,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
@@ -139,6 +140,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
@@ -157,6 +162,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
@@ -964,6 +979,7 @@ class StableDiffusionXLImg2ImgPipeline(
strength: float = 0.3,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
@@ -1022,6 +1038,10 @@ class StableDiffusionXLImg2ImgPipeline(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
denoising_start (`float`, *optional*):
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
@@ -1237,7 +1257,9 @@ class StableDiffusionXLImg2ImgPipeline(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
strength,
...
@@ -269,6 +269,7 @@ def retrieve_timesteps(
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
@@ -284,6 +285,18 @@ def retrieve_timesteps(
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
@@ -302,6 +307,16 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
@@ -1199,6 +1214,7 @@ class StableDiffusionXLInpaintPipeline(
strength: float = 0.9999,
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 7.5,
@@ -1281,6 +1297,10 @@ class StableDiffusionXLInpaintPipeline(
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
denoising_start (`float`, *optional*):
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
@@ -1498,7 +1518,9 @@ class StableDiffusionXLInpaintPipeline(
def denoising_value_valid(dnv):
return isinstance(dnv, float) and 0 < dnv < 1
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps,
strength,
...
@@ -83,6 +83,66 @@ def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: s
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
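The guard in `retrieve_timesteps` relies on signature introspection rather than checking scheduler types. A minimal, self-contained sketch of that dispatch, using hypothetical toy scheduler classes as stand-ins for real diffusers schedulers:

```python
import inspect

class LegacyScheduler:
    """Hypothetical scheduler whose set_timesteps only takes a step count."""
    def set_timesteps(self, num_inference_steps, device=None):
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

class SigmaScheduler:
    """Hypothetical scheduler that also accepts a custom sigma schedule."""
    def set_timesteps(self, num_inference_steps=None, device=None, sigmas=None):
        n = len(sigmas) if sigmas is not None else num_inference_steps
        self.timesteps = list(range(n - 1, -1, -1))

def accepts(scheduler, param: str) -> bool:
    # Same check retrieve_timesteps performs before forwarding custom schedules.
    return param in inspect.signature(scheduler.set_timesteps).parameters

print(accepts(LegacyScheduler(), "sigmas"))  # False -> retrieve_timesteps raises ValueError
print(accepts(SigmaScheduler(), "sigmas"))   # True  -> sigmas are forwarded
```

Because the check is purely structural, any third-party scheduler that adds a `sigmas` keyword to `set_timesteps` opts into custom schedules without registering anywhere.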
@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
    r"""
...@@ -343,6 +403,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
sigmas: Optional[List[float]] = None,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
...@@ -374,6 +435,10 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference. This parameter is modulated by `strength`.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                The minimum guidance scale. Used for the classifier free guidance with first frame.
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
...@@ -492,8 +557,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
        added_time_ids = added_time_ids.to(device)
        # 6. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
...
...@@ -124,6 +124,7 @@ def retrieve_timesteps(
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
...@@ -139,14 +140,18 @@ def retrieve_timesteps(
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.
    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
...@@ -157,6 +162,16 @@ def retrieve_timesteps(
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigma schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
...@@ -669,6 +684,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
sigmas: List[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
...@@ -707,6 +723,10 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
...@@ -816,7 +836,9 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
...
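With the new argument in place, any descending list of floats can serve as the schedule. As a sketch, a Karras-style sigma ramp could be built by hand and passed as `sigmas=` to a pipeline that supports it; the `rho` and sigma bounds below are illustrative values, not defaults read from any scheduler config:

```python
def karras_sigmas(n: int, sigma_min: float = 0.03, sigma_max: float = 14.6, rho: float = 7.0):
    """Build n descending sigmas per the Karras et al. (2022) ramp, plus the
    trailing 0.0 that schedulers append to terminate the denoising loop."""
    max_inv = sigma_max ** (1 / rho)
    min_inv = sigma_min ** (1 / rho)
    ramp = [i / (n - 1) for i in range(n)]
    sigmas = [(max_inv + r * (min_inv - max_inv)) ** rho for r in ramp]
    return sigmas + [0.0]

sigmas = karras_sigmas(10)
# Hypothetical usage with a sigma-aware pipeline:
#   image = pipe(prompt, sigmas=sigmas).images[0]
```

This mirrors what `use_karras_sigmas=True` computes internally, but doing it by hand shows the shape of what `retrieve_timesteps` forwards: a plain descending list, which the scheduler converts into its timestep schedule.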