Unverified commit cd6ca9df authored by Aryan, committed by GitHub

Fix prepare latent image ids and vae sample generators for flux (#9981)

* fix

* update expected slice
parent e564abe2
@@ -513,7 +513,7 @@ class FluxPipeline(
         shape = (batch_size, num_channels_latents, height, width)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids

         if isinstance(generator, list) and len(generator) != batch_size:
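For context on the change above: Flux packs its latents into 2x2 patches before they reach the transformer, so the positional ids have to describe a (height // 2) x (width // 2) grid rather than the full latent resolution. A minimal, runnable sketch of that idea (the helper name and exact layout are illustrative, not the pipeline's internals):

import torch

# Illustrative stand-in for the pipeline's id helper: one (0, row, col) id per
# token of the packed latent grid. Because Flux packs 2x2 latent patches, the
# caller passes height // 2 and width // 2 so the id count matches the token count.
def make_latent_image_ids(height: int, width: int, device, dtype) -> torch.Tensor:
    ids = torch.zeros(height, width, 3)
    ids[..., 1] = torch.arange(height)[:, None]  # row index per token
    ids[..., 2] = torch.arange(width)[None, :]   # column index per token
    return ids.reshape(height * width, 3).to(device=device, dtype=dtype)

# A 64x64 latent grid packs into 32x32 tokens, i.e. 1024 ids.
print(make_latent_image_ids(64 // 2, 64 // 2, "cpu", torch.float32).shape)  # torch.Size([1024, 3])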
@@ -97,6 +97,20 @@ def calculate_shift(
     return mu


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
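The point of threading the generator into this helper is reproducibility: latent_dist.sample(generator) draws the VAE posterior noise from an explicit RNG instead of torch's global state. A self-contained sketch demonstrating this with stand-in classes (FakeLatentDist and FakeEncoderOutput are assumptions that only mimic the shape of the real encoder output):

import torch
from typing import Optional

class FakeLatentDist:
    """Stand-in for the VAE's posterior distribution: sample() draws from torch's RNG."""
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean, self.std = mean, std
    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        return self.mean + self.std * torch.randn(self.mean.shape, generator=generator)
    def mode(self) -> torch.Tensor:
        return self.mean

class FakeEncoderOutput:
    """Stand-in for the encoder output, exposing only .latent_dist."""
    def __init__(self, latent_dist: FakeLatentDist):
        self.latent_dist = latent_dist

def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    # Same dispatch logic as the helper added in this diff.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")

out = FakeEncoderOutput(FakeLatentDist(torch.zeros(1, 16, 8, 8), torch.ones(1, 16, 8, 8)))
a = retrieve_latents(out, generator=torch.Generator().manual_seed(0))
b = retrieve_latents(out, generator=torch.Generator().manual_seed(0))
print(torch.equal(a, b))  # True: the same seed yields the same "VAE" sample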
@@ -512,7 +526,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         shape = (batch_size, num_channels_latents, height, width)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids

         if isinstance(generator, list) and len(generator) != batch_size:
@@ -772,7 +786,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
             if self.controlnet.input_hint_block is None:
                 # vae encode
-                control_image = self.vae.encode(control_image).latent_dist.sample()
+                control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
                 control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                 # pack
@@ -810,7 +824,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
                 if self.controlnet.nets[0].input_hint_block is None:
                     # vae encode
-                    control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
                     control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                     # pack
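Right after each of these encode calls the control latents are packed for the transformer. A small sketch of that packing (the helper name is an assumption; it mirrors the 2x2 patchify Flux applies) shows why the latent image ids are built for height // 2 and width // 2: the packed sequence has exactly (H // 2) * (W // 2) tokens.

import torch

# Assumed helper name; mirrors the 2x2 patchify applied to Flux latents.
def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

control_latents = torch.randn(1, 16, 64, 64)   # (B, C, H, W) VAE latents
packed = pack_latents(control_latents)
print(packed.shape)  # torch.Size([1, 1024, 64]); 1024 == (64 // 2) * (64 // 2) tokens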
@@ -801,7 +801,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             )
             height, width = control_image.shape[-2:]

-            control_image = self.vae.encode(control_image).latent_dist.sample()
+            control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
             control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

             height_control_image, width_control_image = control_image.shape[2:]
@@ -832,7 +832,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 )
                 height, width = control_image_.shape[-2:]

-                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+                control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
                 control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                 height_control_image, width_control_image = control_image_.shape[2:]
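The line following each encode applies the VAE's latent normalization. A tiny sketch of that scale-and-shift with illustrative values (the real numbers come from self.vae.config) and its inverse, which is applied before decoding:

import torch

shift_factor, scaling_factor = 0.1159, 0.3611  # illustrative values; the pipeline reads them from self.vae.config

latents = torch.randn(1, 16, 64, 64)
normalized = (latents - shift_factor) * scaling_factor   # what the pipeline hands to the transformer
recovered = normalized / scaling_factor + shift_factor   # inverse applied before vae.decode
print(torch.allclose(latents, recovered, atol=1e-6))     # True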
@@ -942,7 +942,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
             if self.controlnet.input_hint_block is None:
                 # vae encode
-                control_image = self.vae.encode(control_image).latent_dist.sample()
+                control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
                 control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                 # pack
@@ -979,7 +979,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 if self.controlnet.nets[0].input_hint_block is None:
                     # vae encode
-                    control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
                     control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

                     # pack
@@ -170,7 +170,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
         assert image.shape == (1, 32, 32, 3)

         expected_slice = np.array(
-            [0.7348633, 0.41333008, 0.6621094, 0.5444336, 0.47607422, 0.5859375, 0.44677734, 0.4506836, 0.40454102]
+            [0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
         )

         assert (
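The expected slice changes because the control image is now sampled from the VAE through the seeded generator the test passes in; the comparison pattern itself is unchanged. A hedged sketch of that pattern, with a dummy array standing in for the pipeline output:

import numpy as np

# Dummy stand-in for pipe(**inputs).images; the real tests produce it with a
# seeded torch.Generator, which is what keeps this slice stable across runs.
image = np.random.RandomState(0).random_sample((1, 32, 32, 3)).astype(np.float32)

image_slice = image[0, -3:, -3:, -1]
expected_slice = image_slice.flatten().copy()  # hard-coded reference values in the real test

assert image.shape == (1, 32, 32, 3)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2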