Unverified commit a7d50524, authored by hlky and committed by GitHub
Browse files

Add `dynamic_shifting` to SD3 (#10236)

* Add `dynamic_shifting` to SD3

* calculate_shift

* FlowMatchHeunDiscreteScheduler doesn't support mu

* Inpaint/img2img
parent 672bd495
...@@ -68,6 +68,20 @@ EXAMPLE_DOC_STRING = """ ...@@ -68,6 +68,20 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly interpolate the `mu` shift value for a given image sequence length.

    The result lies on the straight line through the points `(base_seq_len, base_shift)` and
    `(max_seq_len, max_shift)`. Values of `image_seq_len` outside that range are extrapolated along the same
    line; no clamping is applied.

    Args:
        image_seq_len (`int`): Sequence length to compute the shift for.
        base_seq_len (`int`, defaults to 256): Sequence length at which the result equals `base_shift`.
        max_seq_len (`int`, defaults to 4096): Sequence length at which the result equals `max_shift`.
        base_shift (`float`, defaults to 0.5): Shift value at `base_seq_len`.
        max_shift (`float`, defaults to 1.16): Shift value at `max_seq_len`.

    Returns:
        `float`: The interpolated `mu` value.

    Raises:
        `ZeroDivisionError`: If `max_seq_len` equals `base_seq_len`.
    """
    # Slope and intercept of the line; the operation order is kept as-is so that floating-point
    # results stay bit-identical to the reference implementation this function is copied from.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps( def retrieve_timesteps(
scheduler, scheduler,
...@@ -702,6 +716,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -702,6 +716,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
skip_layer_guidance_scale: int = 2.8, skip_layer_guidance_scale: int = 2.8,
skip_layer_guidance_stop: int = 0.2, skip_layer_guidance_stop: int = 0.2,
skip_layer_guidance_start: int = 0.01, skip_layer_guidance_start: int = 0.01,
mu: Optional[float] = None,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -802,6 +817,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -802,6 +817,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
`skip_guidance_layers` will start. The guidance will be applied to the layers specified in `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
`skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
StabiltyAI for Stable Diffusion 3.5 Medium is 0.01. StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples: Examples:
...@@ -882,12 +898,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -882,12 +898,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare timesteps # 4. Prepare latent variables
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents( latents = self.prepare_latents(
batch_size * num_images_per_prompt, batch_size * num_images_per_prompt,
...@@ -900,6 +911,33 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle ...@@ -900,6 +911,33 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
latents, latents,
) )
# 5. Prepare timesteps
scheduler_kwargs = {}
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
_, _, height, width = latents.shape
image_seq_len = (height // self.transformer.config.patch_size) * (
width // self.transformer.config.patch_size
)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop # 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
......
...@@ -75,6 +75,20 @@ EXAMPLE_DOC_STRING = """ ...@@ -75,6 +75,20 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly interpolate the `mu` shift value for a given image sequence length.

    The result lies on the straight line through the points `(base_seq_len, base_shift)` and
    `(max_seq_len, max_shift)`. Values of `image_seq_len` outside that range are extrapolated along the same
    line; no clamping is applied.

    Args:
        image_seq_len (`int`): Sequence length to compute the shift for.
        base_seq_len (`int`, defaults to 256): Sequence length at which the result equals `base_shift`.
        max_seq_len (`int`, defaults to 4096): Sequence length at which the result equals `max_shift`.
        base_shift (`float`, defaults to 0.5): Shift value at `base_seq_len`.
        max_shift (`float`, defaults to 1.16): Shift value at `max_seq_len`.

    Returns:
        `float`: The interpolated `mu` value.

    Raises:
        `ZeroDivisionError`: If `max_seq_len` equals `base_seq_len`.
    """
    # Slope and intercept of the line; the operation order is kept as-is so that floating-point
    # results stay bit-identical to the reference implementation this function is copied from.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents( def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
...@@ -748,6 +762,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -748,6 +762,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256, max_sequence_length: int = 256,
mu: Optional[float] = None,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -832,6 +847,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -832,6 +847,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples: Examples:
...@@ -913,7 +929,24 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -913,7 +929,24 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
image = self.image_processor.preprocess(image, height=height, width=width) image = self.image_processor.preprocess(image, height=height, width=width)
# 4. Prepare timesteps # 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) scheduler_kwargs = {}
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
int(width) // self.vae_scale_factor // self.transformer.config.patch_size
)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
......
...@@ -74,6 +74,20 @@ EXAMPLE_DOC_STRING = """ ...@@ -74,6 +74,20 @@ EXAMPLE_DOC_STRING = """
""" """
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly interpolate the `mu` shift value for a given image sequence length.

    The result lies on the straight line through the points `(base_seq_len, base_shift)` and
    `(max_seq_len, max_shift)`. Values of `image_seq_len` outside that range are extrapolated along the same
    line; no clamping is applied.

    Args:
        image_seq_len (`int`): Sequence length to compute the shift for.
        base_seq_len (`int`, defaults to 256): Sequence length at which the result equals `base_shift`.
        max_seq_len (`int`, defaults to 4096): Sequence length at which the result equals `max_shift`.
        base_shift (`float`, defaults to 0.5): Shift value at `base_seq_len`.
        max_shift (`float`, defaults to 1.16): Shift value at `max_seq_len`.

    Returns:
        `float`: The interpolated `mu` value.

    Raises:
        `ZeroDivisionError`: If `max_seq_len` equals `base_seq_len`.
    """
    # Slope and intercept of the line; the operation order is kept as-is so that floating-point
    # results stay bit-identical to the reference implementation this function is copied from.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents( def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
...@@ -838,6 +852,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -838,6 +852,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256, max_sequence_length: int = 256,
mu: Optional[float] = None,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -947,6 +962,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -947,6 +962,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
Examples: Examples:
...@@ -1023,7 +1039,24 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro ...@@ -1023,7 +1039,24 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 3. Prepare timesteps # 3. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) scheduler_kwargs = {}
if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
int(width) // self.vae_scale_factor // self.transformer.config.patch_size
)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
scheduler_kwargs["mu"] = mu
elif mu is not None:
scheduler_kwargs["mu"] = mu
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
# check that number of inference steps is not < 1 - as this doesn't make sense # check that number of inference steps is not < 1 - as this doesn't make sense
if num_inference_steps < 1: if num_inference_steps < 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment