"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "723933f5f18ffe6889f17e08eb3fa0866b27f494"
Unverified Commit 86bd991e authored by v2ray's avatar v2ray Committed by GitHub
Browse files

Fixed noise_pred_text referenced before assignment. (#9537)

* Fixed local variable noise_pred_text referenced before assignment when using PAG with guidance scale and guidance rescale at the same time.

* Fixed style.

* Made returning text pred noise an argument.
parent 02eeb8e7
...@@ -98,7 +98,9 @@ class PAGMixin: ...@@ -98,7 +98,9 @@ class PAGMixin:
else: else:
return self.pag_scale return self.pag_scale
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): def _apply_perturbed_attention_guidance(
self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
):
r""" r"""
Apply perturbed attention guidance to the noise prediction. Apply perturbed attention guidance to the noise prediction.
...@@ -107,9 +109,11 @@ class PAGMixin: ...@@ -107,9 +109,11 @@ class PAGMixin:
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
guidance_scale (float): The scale factor for the guidance term. guidance_scale (float): The scale factor for the guidance term.
t (int): The current time step. t (int): The current time step.
return_pred_text (bool): Whether to return the text noise prediction.
Returns: Returns:
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
perturbed attention guidance and the text noise prediction.
""" """
pag_scale = self._get_pag_scale(t) pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance: if do_classifier_free_guidance:
...@@ -122,6 +126,8 @@ class PAGMixin: ...@@ -122,6 +126,8 @@ class PAGMixin:
else: else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
if return_pred_text:
return noise_pred, noise_pred_text
return noise_pred return noise_pred
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
......
...@@ -893,8 +893,8 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): ...@@ -893,8 +893,8 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
# perform guidance # perform guidance
if self.do_perturbed_attention_guidance: if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance( noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
) )
elif self.do_classifier_free_guidance: elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
......
...@@ -993,8 +993,8 @@ class StableDiffusionPAGPipeline( ...@@ -993,8 +993,8 @@ class StableDiffusionPAGPipeline(
# perform guidance # perform guidance
if self.do_perturbed_attention_guidance: if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance( noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
) )
elif self.do_classifier_free_guidance: elif self.do_classifier_free_guidance:
......
...@@ -1237,8 +1237,8 @@ class StableDiffusionXLPAGPipeline( ...@@ -1237,8 +1237,8 @@ class StableDiffusionXLPAGPipeline(
# perform guidance # perform guidance
if self.do_perturbed_attention_guidance: if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance( noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
) )
elif self.do_classifier_free_guidance: elif self.do_classifier_free_guidance:
......
...@@ -1437,8 +1437,8 @@ class StableDiffusionXLPAGImg2ImgPipeline( ...@@ -1437,8 +1437,8 @@ class StableDiffusionXLPAGImg2ImgPipeline(
# perform guidance # perform guidance
if self.do_perturbed_attention_guidance: if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance( noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
) )
elif self.do_classifier_free_guidance: elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
......
...@@ -1649,8 +1649,8 @@ class StableDiffusionXLPAGInpaintPipeline( ...@@ -1649,8 +1649,8 @@ class StableDiffusionXLPAGInpaintPipeline(
# perform guidance # perform guidance
if self.do_perturbed_attention_guidance: if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance( noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
) )
elif self.do_classifier_free_guidance: elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment