Unverified Commit 16a3dad4 authored by Sangwon Lee's avatar Sangwon Lee Committed by GitHub
Browse files

Fix StableDiffusionXLPAGInpaintPipeline (#9128)

parent 21682bab
...@@ -955,7 +955,8 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -955,7 +955,8 @@ class AutoPipelineForInpainting(ConfigMixin):
if "enable_pag" in kwargs: if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag") enable_pag = kwargs.pop("enable_pag")
if enable_pag: if enable_pag:
orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline") to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
......
...@@ -1471,6 +1471,14 @@ class StableDiffusionXLPAGInpaintPipeline( ...@@ -1471,6 +1471,14 @@ class StableDiffusionXLPAGInpaintPipeline(
generator, generator,
self.do_classifier_free_guidance, self.do_classifier_free_guidance,
) )
if self.do_perturbed_attention_guidance:
if self.do_classifier_free_guidance:
mask, _ = mask.chunk(2)
masked_image_latents, _ = masked_image_latents.chunk(2)
mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance)
masked_image_latents = self._prepare_perturbed_attention_guidance(
masked_image_latents, masked_image_latents, self.do_classifier_free_guidance
)
# 8. Check that sizes of mask, masked image and latents match # 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9: if num_channels_unet == 9:
...@@ -1659,10 +1667,10 @@ class StableDiffusionXLPAGInpaintPipeline( ...@@ -1659,10 +1667,10 @@ class StableDiffusionXLPAGInpaintPipeline(
if num_channels_unet == 4: if num_channels_unet == 4:
init_latents_proper = image_latents init_latents_proper = image_latents
if self.do_classifier_free_guidance: if self.do_perturbed_attention_guidance:
init_mask, _ = mask.chunk(2) init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2)
else: else:
init_mask = mask init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask
if i < len(timesteps) - 1: if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1] noise_timestep = timesteps[i + 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment