Unverified commit 9f2d5c9e, authored by hlky, committed by GitHub
Browse files

Flux with Remote Encode (#11091)

* Flux img2img remote encode

* Flux inpaint

* -copied from
parent dc62e693
...@@ -533,7 +533,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin ...@@ -533,7 +533,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
return latents return latents
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents( def prepare_latents(
self, self,
image, image,
......
...@@ -533,7 +533,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -533,7 +533,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
return latents return latents
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents( def prepare_latents(
self, self,
image, image,
......
...@@ -561,7 +561,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -561,7 +561,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
return latents return latents
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
def prepare_latents( def prepare_latents(
self, self,
image, image,
...@@ -614,7 +613,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -614,7 +613,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, noise, image_latents, latent_image_ids return latents, noise, image_latents, latent_image_ids
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
def prepare_mask_latents( def prepare_mask_latents(
self, self,
mask, mask,
......
...@@ -225,7 +225,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -225,7 +225,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+     vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+ )
self.tokenizer_max_length = ( self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
) )
...@@ -634,7 +637,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -634,7 +637,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
return latents.to(device=device, dtype=dtype), latent_image_ids return latents.to(device=device, dtype=dtype), latent_image_ids
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
- image_latents = self._encode_vae_image(image=image, generator=generator)
+ if image.shape[1] != self.latent_channels:
+     image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+     image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size # expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0] additional_image_per_prompt = batch_size // image_latents.shape[0]
......
...@@ -222,11 +222,13 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM ...@@ -222,11 +222,13 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
- latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+     vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+ )
self.mask_processor = VaeImageProcessor( self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2, vae_scale_factor=self.vae_scale_factor * 2,
- vae_latent_channels=latent_channels,
+ vae_latent_channels=self.latent_channels,
do_normalize=False, do_normalize=False,
do_binarize=True, do_binarize=True,
do_convert_grayscale=True, do_convert_grayscale=True,
...@@ -653,7 +655,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM ...@@ -653,7 +655,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
- image_latents = self._encode_vae_image(image=image, generator=generator)
+ if image.shape[1] != self.latent_channels:
+     image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+     image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size # expand init_latents for batch_size
...@@ -710,7 +715,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM ...@@ -710,7 +715,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
else: else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
- masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ masked_image_latents = (
+     masked_image_latents - self.vae.config.shift_factor
+ ) * self.vae.config.scaling_factor
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size: if mask.shape[0] < batch_size:
......
...@@ -367,7 +367,7 @@ def prepare_encode( ...@@ -367,7 +367,7 @@ def prepare_encode(
if shift_factor is not None: if shift_factor is not None:
parameters["shift_factor"] = shift_factor parameters["shift_factor"] = shift_factor
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
- data = safetensors.torch._tobytes(image, "tensor")
+ data = safetensors.torch._tobytes(image.contiguous(), "tensor")
parameters["shape"] = list(image.shape) parameters["shape"] = list(image.shape)
parameters["dtype"] = str(image.dtype).split(".")[-1] parameters["dtype"] = str(image.dtype).split(".")[-1]
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment