"vscode:/vscode.git/clone" did not exist on "7a4324cce3f84d14afe8e5cfd47fb67701ce2fd3"
Unverified commit 94f2c48d, authored by Suprhimp, committed by GitHub

[feat]Add strength in flux_fill pipeline (denoising strength for fluxfill) (#10603)

* [feat]add strength in flux_fill pipeline

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* [refactor] refactor after review

* [fix] change comment

* Apply style fixes

* empty

* fix

* update prepare_latents from flux.img2img pipeline

* style

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

---------
parent aabf8ce2
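For reference, a minimal usage sketch of the new argument (the model id is the public FLUX.1 Fill checkpoint; the input file names and prompt are placeholders, not part of this commit):

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("input.png")  # placeholder path
mask = load_image("mask.png")    # placeholder path, white = region to repaint

# strength < 1.0 starts denoising from a partially noised encoding of `image`
# instead of pure noise, so less of the original content is discarded.
result = pipe(
    prompt="a small wooden cabin in a forest clearing",
    image=image,
    mask_image=mask,
    height=1024,
    width=1024,
    strength=0.85,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
result.save("output.png")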
@@ -224,11 +224,13 @@ class FluxFillPipeline(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
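Passing `vae_latent_channels` lets both processors recognize tensors that are already latents by their channel count. A rough sketch of that behaviour (assuming the standard `VaeImageProcessor` pass-through for latent-shaped inputs; shapes are illustrative):

import torch
from diffusers.image_processor import VaeImageProcessor

# 16 is the Flux VAE latent channel count used as the fallback above.
proc = VaeImageProcessor(vae_scale_factor=16, vae_latent_channels=16)

pixels = torch.rand(1, 3, 1024, 1024)   # a regular image: resized/normalized as usual
latents = torch.randn(1, 16, 128, 128)  # latent-shaped input: expected to pass through preprocess unchanged
print(proc.preprocess(pixels).shape, proc.preprocess(latents).shape)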
@@ -493,10 +495,38 @@ class FluxFillPipeline(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
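How `strength` maps onto the schedule: `get_timesteps` above keeps only the last `num_inference_steps * strength` timesteps. A standalone sketch of the same arithmetic (Flux flow-match schedulers have order 1, so the scheduler-order factor is omitted):

def steps_for_strength(num_inference_steps: int, strength: float) -> tuple[int, int]:
    # Mirrors get_timesteps: skip the first part of the schedule, keep the tail.
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    return t_start, num_inference_steps - t_start

print(steps_for_strength(50, 1.0))  # (0, 50)  full schedule, image is fully re-noised
print(steps_for_strength(50, 0.6))  # (20, 30) skip 20 steps, run only the last 30
print(steps_for_strength(50, 0.0))  # (50, 0)  rejected later by the num_inference_steps < 1 check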
@@ -507,6 +537,9 @@ class FluxFillPipeline(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -624,9 +657,11 @@ class FluxFillPipeline(
         """
         self.vae.disable_tiling()
 
-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
@@ -636,28 +671,41 @@ class FluxFillPipeline(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
+        image = image.to(device=device, dtype=dtype)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         return latents, latent_image_ids
 
     @property
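The key change above: instead of sampling pure Gaussian latents, the starting latents are now the VAE-encoded init image blended with noise at the first kept timestep, as in the Flux img2img pipeline. A reduced sketch of that blend using the scheduler's `scale_noise` (shapes and the bare scheduler config are illustrative):

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=50)

image_latents = torch.randn(1, 16, 128, 128)  # stand-in for the VAE-encoded init image
noise = torch.randn_like(image_latents)

# Noise the image latents to the level of the first (strength-trimmed) timestep;
# with strength == 1.0 this is the top of the schedule and the image is essentially replaced by noise.
t = scheduler.timesteps[:1]
latents = scheduler.scale_noise(image_latents, t, noise)
print(latents.shape)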
@@ -687,6 +735,7 @@ class FluxFillPipeline(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
@@ -731,6 +780,12 @@ class FluxFillPipeline(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -794,6 +849,7 @@ class FluxFillPipeline(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
            height,
            width,
            prompt_embeds=prompt_embeds,
@@ -809,6 +865,9 @@ class FluxFillPipeline(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -838,9 +897,37 @@ class FluxFillPipeline(
             lora_scale=lora_scale,
         )
 
-        # 4. Prepare latent variables
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
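Because the timestep schedule is now built before the latents exist, `image_seq_len` (and hence the dynamic shift `mu`) in the hunk above is derived from the requested `height`/`width` rather than from `latents.shape[1]`; both give the same packed sequence length. A worked sketch of that computation for a 1024x1024 request (re-implemented here for illustration; the pipeline uses the shared `calculate_shift` helper):

def calculate_shift_sketch(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    # Linear interpolation of the shift between base_shift and max_shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

vae_scale_factor = 8
height = width = 1024
image_seq_len = (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)  # 64 * 64 = 4096
print(image_seq_len, calculate_shift_sketch(image_seq_len))  # 4096, ~1.15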
@@ -851,17 +938,16 @@ class FluxFillPipeline(
             latents,
         )
 
-        # 5. Prepare mask and masked image latents
+        # 6. Prepare mask and masked image latents
         if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
 
-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
@@ -876,23 +962,6 @@ class FluxFillPipeline(
             )
             masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
 
-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)