Unverified commit 73b59f52, authored by Ina, committed by GitHub
parent 52d44498

[refactor] enhance readability of flux-related pipelines (#9711)

* flux pipeline: readability enhancement
@@ -2198,8 +2198,8 @@ def main(args):
     latent_image_ids = FluxPipeline._prepare_latent_image_ids(
         model_input.shape[0],
-        model_input.shape[2],
-        model_input.shape[3],
+        model_input.shape[2] // 2,
+        model_input.shape[3] // 2,
         accelerator.device,
         weight_dtype,
     )
@@ -2253,8 +2253,8 @@ def main(args):
     )[0]
     model_pred = FluxPipeline._unpack_latents(
         model_pred,
-        height=int(model_input.shape[2] * vae_scale_factor / 2),
-        width=int(model_input.shape[3] * vae_scale_factor / 2),
+        height=model_input.shape[2] * vae_scale_factor,
+        width=model_input.shape[3] * vae_scale_factor,
         vae_scale_factor=vae_scale_factor,
     )
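The two hunks above read as a pair: assuming vae_scale_factor in this script is recomputed as 2 ** (len(block_out_channels) - 1) (as in the companion training-script hunks below), the halved factor makes the old int(... * vae_scale_factor / 2) expression and the new division-free one agree exactly. A minimal sanity sketch, with the four-entry block_out_channels of the Flux VAE as an assumption:

block_out_channels = [128, 256, 512, 512]  # assumed Flux VAE config

old_vsf = 2 ** len(block_out_channels)        # 16, old convention (VAE factor x 2x2 packing)
new_vsf = 2 ** (len(block_out_channels) - 1)  # 8, plain VAE downsampling factor

h_lat = w_lat = 128  # model_input.shape[2:], e.g. for a 1024x1024 image
assert int(h_lat * old_vsf / 2) == h_lat * new_vsf == 1024
assert int(w_lat * old_vsf / 2) == w_lat * new_vsf == 1024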
......
@@ -1256,8 +1256,8 @@ def main(args):
     latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(
         batch_size=pixel_latents_tmp.shape[0],
-        height=pixel_latents_tmp.shape[2],
-        width=pixel_latents_tmp.shape[3],
+        height=pixel_latents_tmp.shape[2] // 2,
+        width=pixel_latents_tmp.shape[3] // 2,
         device=pixel_values.device,
         dtype=pixel_values.dtype,
     )
......
@@ -1540,12 +1540,12 @@ def main(args):
     model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
     model_input = model_input.to(dtype=weight_dtype)

-    vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)

     latent_image_ids = FluxPipeline._prepare_latent_image_ids(
         model_input.shape[0],
-        model_input.shape[2],
-        model_input.shape[3],
+        model_input.shape[2] // 2,
+        model_input.shape[3] // 2,
         accelerator.device,
         weight_dtype,
     )
@@ -1601,8 +1601,8 @@ def main(args):
     # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
     model_pred = FluxPipeline._unpack_latents(
         model_pred,
-        height=int(model_input.shape[2] * vae_scale_factor / 2),
-        width=int(model_input.shape[3] * vae_scale_factor / 2),
+        height=model_input.shape[2] * vae_scale_factor,
+        width=model_input.shape[3] * vae_scale_factor,
         vae_scale_factor=vae_scale_factor,
     )
......
@@ -1645,12 +1645,12 @@ def main(args):
     model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
     model_input = model_input.to(dtype=weight_dtype)

-    vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+    vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)

     latent_image_ids = FluxPipeline._prepare_latent_image_ids(
         model_input.shape[0],
-        model_input.shape[2],
-        model_input.shape[3],
+        model_input.shape[2] // 2,
+        model_input.shape[3] // 2,
         accelerator.device,
         weight_dtype,
     )
@@ -1704,8 +1704,8 @@ def main(args):
     )[0]
     model_pred = FluxPipeline._unpack_latents(
         model_pred,
-        height=int(model_input.shape[2] * vae_scale_factor / 2),
-        width=int(model_input.shape[3] * vae_scale_factor / 2),
+        height=model_input.shape[2] * vae_scale_factor,
+        width=model_input.shape[3] * vae_scale_factor,
         vae_scale_factor=vae_scale_factor,
     )
......
@@ -195,13 +195,13 @@ class FluxPipeline(
             scheduler=scheduler,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     def _get_t5_prompt_embeds(
         self,
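The fallback factor and default_sample_size move in lockstep, so the pipeline's default pixel resolution is unchanged. A minimal sketch, assuming the four down blocks of the Flux VAE:

num_blocks = 4  # assumption: len(vae.config.block_out_channels) for Flux

old_vsf, old_sample = 2 ** num_blocks, 64         # old: factor folds in the 2x2 packing
new_vsf, new_sample = 2 ** (num_blocks - 1), 128  # new: factor is plain VAE downsampling

assert old_vsf * old_sample == new_vsf * new_sample == 1024  # default resolution preserved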
@@ -386,8 +386,10 @@ class FluxPipeline(
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
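For the Flux VAE the validation is numerically unchanged (the new vae_scale_factor is 8); it is just expressed through the attribute instead of a hard-coded literal. A trivial check:

vae_scale_factor = 8  # new Flux value
for height, width in [(1024, 1024), (768, 1360), (1000, 1024)]:
    assert (height % 8 != 0 or width % 8 != 0) == (
        height % vae_scale_factor != 0 or width % vae_scale_factor != 0
    )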
@@ -425,9 +427,9 @@ class FluxPipeline(
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
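The helper now builds ids over whatever grid it is given, and callers pass the packed (halved) latent dimensions explicitly; calling the new version with halved dims reproduces the old output bit for bit. A runnable equivalence sketch (the functions below paraphrase only the lines shown in this hunk):

import torch

def ids_old(height, width):  # old: halves internally
    ids = torch.zeros(height // 2, width // 2, 3)
    ids[..., 1] += torch.arange(height // 2)[:, None]
    ids[..., 2] += torch.arange(width // 2)[None, :]
    return ids

def ids_new(height, width):  # new: caller passes packed dims
    ids = torch.zeros(height, width, 3)
    ids[..., 1] += torch.arange(height)[:, None]
    ids[..., 2] += torch.arange(width)[None, :]
    return ids

h = w = 128  # latent dims of a 1024x1024 image at vae_scale_factor 8
assert torch.equal(ids_old(h, w), ids_new(h // 2, w // 2))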
@@ -452,10 +454,10 @@ class FluxPipeline(
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents
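With the halved factor, height // vae_scale_factor now lands directly on the latent size, and the 2x2 token packing becomes visible as an explicit // 2 in the view. A self-contained round-trip sketch; pack paraphrases _pack_latents under the usual Flux layout, and unpack mirrors the new code above:

import torch

def pack(latents):  # (B, C, H, W) -> (B, H/2 * W/2, C * 4)
    b, c, h, w = latents.shape
    x = latents.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack(latents, height, width, vae_scale_factor):  # new _unpack_latents
    b, _, channels = latents.shape
    height = height // vae_scale_factor  # pixel -> latent, a single division
    width = width // vae_scale_factor
    x = latents.view(b, height // 2, width // 2, channels // 4, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, channels // (2 * 2), height, width)

lat = torch.randn(1, 16, 128, 128)  # a 1024x1024 image at vae_scale_factor 8
assert torch.equal(unpack(pack(lat), 1024, 1024, 8), lat)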
@@ -499,8 +501,8 @@ class FluxPipeline(
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)
@@ -517,7 +519,7 @@ class FluxPipeline(
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         return latents, latent_image_ids
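The latent shapes produced here are identical under both conventions; only the factoring changes, and the packing halving now sits visibly at the call site. For example:

# Old: vae_scale_factor = 16, dims pre-doubled to survive the helper's internal // 2.
# New: vae_scale_factor = 8, dims are plain latent sizes; the // 2 moves to the caller.
height = 1024
assert 2 * (height // 16) == height // 8 == 128   # latent grid, both conventions
assert (height // 8) // 2 == height // 16 == 64   # packed ids grid, both conventions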
......
@@ -216,13 +216,13 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             controlnet=controlnet,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     def _get_t5_prompt_embeds(
         self,
@@ -410,8 +410,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -450,9 +452,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -479,10 +481,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents
@@ -498,8 +500,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)
@@ -516,7 +518,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

        return latents, latent_image_ids
......
@@ -228,13 +228,13 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             controlnet=controlnet,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -453,8 +453,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -493,9 +495,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -522,10 +524,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents
@@ -549,11 +551,11 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -852,7 +854,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             control_mode = control_mode.reshape([-1, 1])

         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
             self.scheduler.config.base_image_seq_len,
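image_seq_len counts packed latent tokens, so each spatial dimension is now divided once by the VAE factor and once more by the 2x2 packing; the old single division by 16 did both at once. The values match, e.g. at 1024x1024:

old_seq = (1024 // 16) * (1024 // 16)          # old: vae_scale_factor = 16
new_seq = (1024 // 8 // 2) * (1024 // 8 // 2)  # new: vae_scale_factor = 8, explicit // 2
assert old_seq == new_seq == 4096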
......
@@ -231,7 +231,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
@@ -244,7 +244,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -467,8 +467,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -520,9 +522,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -549,10 +551,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents
@@ -576,11 +578,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         image = image.to(device=device, dtype=dtype)
         image_latents = self._encode_vae_image(image=image, generator=generator)
@@ -622,8 +624,8 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         device,
         generator,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
@@ -996,7 +998,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         # 6. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
+        image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
+            int(global_width) // self.vae_scale_factor // 2
+        )
         mu = calculate_shift(
             image_seq_len,
             self.scheduler.config.base_image_seq_len,
......
@@ -212,13 +212,13 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             scheduler=scheduler,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -437,8 +437,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -477,9 +479,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -506,10 +508,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents
@@ -532,11 +534,11 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -736,7 +738,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # 4.Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
             self.scheduler.config.base_image_seq_len,
......
@@ -209,7 +209,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             scheduler=scheduler,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
@@ -222,7 +222,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128

     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -445,8 +445,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -498,9 +500,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
@@ -527,10 +529,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         height = height // vae_scale_factor
         width = width // vae_scale_factor

-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

         return latents
@@ -553,11 +555,11 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         image = image.to(device=device, dtype=dtype)
         image_latents = self._encode_vae_image(image=image, generator=generator)
@@ -598,8 +600,8 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         device,
         generator,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        height = int(height) // self.vae_scale_factor
+        width = int(width) // self.vae_scale_factor

         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
@@ -866,7 +868,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         # 4.Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
             self.scheduler.config.base_image_seq_len,
......