Unverified Commit f6f7afa1 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Flux latents fix (#9929)



* update

* update

* update

* update

* update

* update

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 637e2302
...@@ -197,7 +197,9 @@ class FluxPipeline( ...@@ -197,7 +197,9 @@ class FluxPipeline(
self.vae_scale_factor = ( self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
) )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = ( self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
) )
...@@ -386,9 +388,9 @@ class FluxPipeline( ...@@ -386,9 +388,9 @@ class FluxPipeline(
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None, max_sequence_length=None,
): ):
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError( logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
) )
if callback_on_step_end_tensor_inputs is not None and not all( if callback_on_step_end_tensor_inputs is not None and not all(
...@@ -451,8 +453,10 @@ class FluxPipeline( ...@@ -451,8 +453,10 @@ class FluxPipeline(
def _unpack_latents(latents, height, width, vae_scale_factor): def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = width // vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.permute(0, 3, 1, 4, 2, 5)
...@@ -501,8 +505,10 @@ class FluxPipeline( ...@@ -501,8 +505,10 @@ class FluxPipeline(
generator, generator,
latents=None, latents=None,
): ):
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width) shape = (batch_size, num_channels_latents, height, width)
......
...@@ -218,7 +218,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -218,7 +218,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
self.vae_scale_factor = ( self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
) )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = ( self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
) )
...@@ -410,9 +412,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -410,9 +412,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None, max_sequence_length=None,
): ):
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError( logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
) )
if callback_on_step_end_tensor_inputs is not None and not all( if callback_on_step_end_tensor_inputs is not None and not all(
...@@ -478,8 +480,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -478,8 +480,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
def _unpack_latents(latents, height, width, vae_scale_factor): def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = width // vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.permute(0, 3, 1, 4, 2, 5)
...@@ -500,8 +504,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -500,8 +504,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
generator, generator,
latents=None, latents=None,
): ):
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width) shape = (batch_size, num_channels_latents, height, width)
......
...@@ -230,7 +230,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -230,7 +230,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
self.vae_scale_factor = ( self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
) )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = ( self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
) )
...@@ -453,9 +455,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -453,9 +455,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0:
raise ValueError( logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
) )
if callback_on_step_end_tensor_inputs is not None and not all( if callback_on_step_end_tensor_inputs is not None and not all(
...@@ -521,8 +523,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -521,8 +523,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def _unpack_latents(latents, height, width, vae_scale_factor): def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = width // vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.permute(0, 3, 1, 4, 2, 5)
...@@ -551,9 +555,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -551,9 +555,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
f" size of {batch_size}. Make sure the batch size matches the length of the generators." f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) )
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width) shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
...@@ -873,7 +878,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -873,7 +878,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents, latent_image_ids = self.prepare_latents( latents, latent_image_ids = self.prepare_latents(
init_image, init_image,
latent_timestep, latent_timestep,
......
...@@ -233,9 +233,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -233,9 +233,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
self.vae_scale_factor = ( self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
) )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor( self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=self.vae.config.latent_channels, vae_latent_channels=self.vae.config.latent_channels,
do_normalize=False, do_normalize=False,
do_binarize=True, do_binarize=True,
...@@ -467,9 +469,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -467,9 +469,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError( logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
) )
if callback_on_step_end_tensor_inputs is not None and not all( if callback_on_step_end_tensor_inputs is not None and not all(
...@@ -548,8 +550,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -548,8 +550,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def _unpack_latents(latents, height, width, vae_scale_factor): def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = width // vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.permute(0, 3, 1, 4, 2, 5)
...@@ -578,9 +582,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -578,9 +582,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
f" size of {batch_size}. Make sure the batch size matches the length of the generators." f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) )
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width) shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
...@@ -624,8 +629,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -624,8 +629,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
device, device,
generator, generator,
): ):
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents # resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision # and half precision
...@@ -663,7 +670,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -663,7 +670,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
# aligning device to prevent device errors when concating it with the latent model input # aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
masked_image_latents = self._pack_latents( masked_image_latents = self._pack_latents(
masked_image_latents, masked_image_latents,
batch_size, batch_size,
......
...@@ -214,7 +214,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -214,7 +214,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
self.vae_scale_factor = ( self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
) )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = ( self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
) )
...@@ -437,9 +439,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -437,9 +439,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError( logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
) )
if callback_on_step_end_tensor_inputs is not None and not all( if callback_on_step_end_tensor_inputs is not None and not all(
...@@ -505,8 +507,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -505,8 +507,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
def _unpack_latents(latents, height, width, vae_scale_factor): def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = width // vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.permute(0, 3, 1, 4, 2, 5)
...@@ -534,9 +538,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -534,9 +538,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
f" size of {batch_size}. Make sure the batch size matches the length of the generators." f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) )
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width) shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
......
...@@ -211,9 +211,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -211,9 +211,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
self.vae_scale_factor = ( self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
) )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor( self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=self.vae.config.latent_channels, vae_latent_channels=self.vae.config.latent_channels,
do_normalize=False, do_normalize=False,
do_binarize=True, do_binarize=True,
...@@ -445,9 +447,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -445,9 +447,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
if strength < 0 or strength > 1: if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
raise ValueError( logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
) )
if callback_on_step_end_tensor_inputs is not None and not all( if callback_on_step_end_tensor_inputs is not None and not all(
...@@ -526,8 +528,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -526,8 +528,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
def _unpack_latents(latents, height, width, vae_scale_factor): def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = width // vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.permute(0, 3, 1, 4, 2, 5)
...@@ -555,9 +559,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -555,9 +559,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
f" size of {batch_size}. Make sure the batch size matches the length of the generators." f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) )
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width) shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
...@@ -600,8 +605,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -600,8 +605,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
device, device,
generator, generator,
): ):
height = int(height) // self.vae_scale_factor # VAE applies 8x compression on images but we must also account for packing which requires
width = int(width) // self.vae_scale_factor # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents # resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision # and half precision
...@@ -639,7 +646,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -639,7 +646,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
# aligning device to prevent device errors when concating it with the latent model input # aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
masked_image_latents = self._pack_latents( masked_image_latents = self._pack_latents(
masked_image_latents, masked_image_latents,
batch_size, batch_size,
......
...@@ -181,6 +181,28 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ...@@ -181,6 +181,28 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
def test_xformers_attention_forwardGenerator_pass(self): def test_xformers_attention_forwardGenerator_pass(self):
pass pass
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 56)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update(
{
"control_image": randn_tensor(
(1, 3, height, width),
device=torch_device,
dtype=torch.float16,
)
}
)
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
@slow @slow
@require_big_gpu_with_torch_cuda @require_big_gpu_with_torch_cuda
......
...@@ -14,6 +14,7 @@ from diffusers import ( ...@@ -14,6 +14,7 @@ from diffusers import (
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
torch_device, torch_device,
) )
from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import ( from ..test_pipelines_common import (
PipelineTesterMixin, PipelineTesterMixin,
...@@ -218,3 +219,31 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi ...@@ -218,3 +219,31 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
assert np.allclose( assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled." ), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 56)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update(
{
"control_image": randn_tensor(
(1, 3, height, width),
device=torch_device,
dtype=torch.float16,
),
"image": randn_tensor(
(1, 3, height, width),
device=torch_device,
dtype=torch.float16,
),
"height": height,
"width": width,
}
)
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
...@@ -23,7 +23,9 @@ from diffusers import ( ...@@ -23,7 +23,9 @@ from diffusers import (
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
enable_full_determinism, enable_full_determinism,
floats_tensor, floats_tensor,
torch_device,
) )
from diffusers.utils.torch_utils import randn_tensor
from ..test_pipelines_common import PipelineTesterMixin from ..test_pipelines_common import PipelineTesterMixin
...@@ -192,3 +194,33 @@ class FluxControlNetInpaintPipelineTests(unittest.TestCase, PipelineTesterMixin) ...@@ -192,3 +194,33 @@ class FluxControlNetInpaintPipelineTests(unittest.TestCase, PipelineTesterMixin)
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3) super().test_inference_batch_single_identical(expected_max_diff=3e-3)
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 56)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update(
{
"control_image": randn_tensor(
(1, 3, height, width),
device=torch_device,
dtype=torch.float16,
),
"image": randn_tensor(
(1, 3, height, width),
device=torch_device,
dtype=torch.float16,
),
"mask_image": torch.ones((1, 1, height, width)).to(torch_device),
"height": height,
"width": width,
}
)
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
...@@ -191,6 +191,20 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ...@@ -191,6 +191,20 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled." ), "Original outputs should match when fused QKV projections are disabled."
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
@slow @slow
@require_big_gpu_with_torch_cuda @require_big_gpu_with_torch_cuda
......
...@@ -147,3 +147,17 @@ class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ...@@ -147,3 +147,17 @@ class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
max_diff = np.abs(output_with_prompt - output_with_embeds).max() max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4 assert max_diff < 1e-4
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
...@@ -149,3 +149,17 @@ class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ...@@ -149,3 +149,17 @@ class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
max_diff = np.abs(output_with_prompt - output_with_embeds).max() max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4 assert max_diff < 1e-4
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment