# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
)
from ...utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class PAGMixin:
    r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1)."""

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.
        """
        pag_attn_processors = self._pag_attn_processors
        if pag_attn_processors is None:
            raise ValueError(
                "No PAG attention processors have been set. Set the attention processors by calling `set_pag_applied_layers` and passing the relevant parameters."
            )

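        # `pag_attn_processors` holds (processor for PAG with CFG, processor for PAG without CFG);
        # pick the one matching whether classifier-free guidance is enabled for this call.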
        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        if hasattr(self, "unet"):
            model: nn.Module = self.unet
        else:
            model: nn.Module = self.transformer

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check if the module is self-attention module based on its name.
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        def is_fake_integral_match(layer_id, name):
            layer_id = layer_id.split(".")[-1]
            name = name.split(".")[-1]
            return layer_id.isnumeric() and name.isnumeric() and layer_id == name

        for layer_id in pag_applied_layers:
            # for each PAG layer input, we find the corresponding self-attention layers in the denoiser model
            target_modules = []

            for name, module in model.named_modules():
                # Identify the following simple cases:
                #   (1) The module is a self-attention layer
                #   (2) The module name matches the PAG layer id, at least partially
                #   (3) The match is not a fake integral match when the layer_id ends with a number
                #       For example, blocks.1 and blocks.10 should be differentiated if layer_id="blocks.1"
                if (
                    is_self_attn(module)
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc

    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.
        """

        if self.do_pag_adaptive_scaling:
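            # Linearly reduce the PAG scale as denoising progresses (diffusion timesteps typically
            # count down from ~1000 to 0), clamping at zero so the guidance term never flips sign.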
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    def _apply_perturbed_attention_guidance(
        self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
    ):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.
            return_pred_text (bool): Whether to return the text noise prediction.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
            perturbed attention guidance and, if `return_pred_text` is `True`, the text noise prediction.
        """
        pag_scale = self._get_pag_scale(t)
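        # `noise_pred` comes from a batch laid out by `_prepare_perturbed_attention_guidance`:
        # [uncond, text, perturbed] when CFG is enabled, [text, perturbed] otherwise.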
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        if return_pred_text:
            return noise_pred, noise_pred_text
        return noise_pred

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepares the perturbed attention guidance for the PAG model.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
        """

        cond = torch.cat([cond] * 2, dim=0)
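        # The second copy of `cond` is the branch that the PAG attention processor perturbs with
        # identity self-attention; with CFG, `uncond` is prepended to form [uncond, cond, cond].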

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    def set_pag_applied_layers(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG to. Raises a `ValueError` if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
                PAG is to be applied. A few ways of expected usage are as follows:
                  - Single layers specified as - "blocks.{layer_index}"
                  - Multiple layers as a list - ["blocks.{layer_index_1}", "blocks.{layer_index_2}", ...]
                  - Multiple layers as a block name - "mid"
                  - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors:
                (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
                PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
                processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
                attention processor is for PAG with CFG disabled (unconditional only).
        """

        if not hasattr(self, "_pag_attn_processors"):
            self._pag_attn_processors = None

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for i in range(len(pag_applied_layers)):
            if not isinstance(pag_applied_layers[i], str):
                raise ValueError(
                    f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
                )

        self.pag_applied_layers = pag_applied_layers
        self._pag_attn_processors = pag_attn_processors

    @property
    def pag_scale(self) -> float:
        r"""Get the scale factor for the perturbed attention guidance."""
        return self._pag_scale

    @property
    def pag_adaptive_scale(self) -> float:
        r"""Get the adaptive scale factor for the perturbed attention guidance."""
        return self._pag_adaptive_scale

    @property
    def do_pag_adaptive_scaling(self) -> bool:
        r"""Check if the adaptive scaling is enabled for the perturbed attention guidance."""
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def do_perturbed_attention_guidance(self) -> bool:
        r"""Check if the perturbed attention guidance is enabled."""
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def pag_attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the
            model, keyed by the name of the layer.
        """

        if self._pag_attn_processors is None:
            return {}

        valid_attn_processors = {x.__class__ for x in self._pag_attn_processors}

        processors = {}
        # We could have iterated through `self.components.items()` and checked whether a component
        # subclasses `ModelMixin`, but that could also include a VAE.
        if hasattr(self, "unet"):
            denoiser_module = self.unet
        elif hasattr(self, "transformer"):
            denoiser_module = self.transformer
        else:
            raise ValueError("No denoiser module found.")

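        # Collect every attention processor on the denoiser whose class matches one of the
        # registered PAG processor classes.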
        for name, proc in denoiser_module.attn_processors.items():
            if proc.__class__ in valid_attn_processors:
                processors[name] = proc

        return processors