# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ...models.attention_processor import (
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
)
from ...utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class PAGMixin:
    r"""Mixin class for PAG."""

    @staticmethod
    def _check_input_pag_applied_layer(layer):
        r"""
        Check if each layer input in `pag_applied_layers` is valid. It should be in one of these three formats:
        "{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
        can be "down", "mid", or "up". `block_index` should be in the format of "block_{i}", and `attention_index`
        should be in the format of "attentions_{j}" or "motion_modules_{j}".
        """

        layer_splits = layer.split(".")

        if len(layer_splits) > 3:
            raise ValueError(f"pag layer should only contains block_type, block_index and attention_index{layer}.")

        if len(layer_splits) >= 1:
            if layer_splits[0] not in ["down", "mid", "up"]:
                raise ValueError(
                    f"Invalid block_type in pag layer {layer}. Accept 'down', 'mid', 'up', got {layer_splits[0]}"
                )

        if len(layer_splits) >= 2:
            if not layer_splits[1].startswith("block_"):
                raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")

        if len(layer_splits) == 3:
            layer_2 = layer_splits[2]
            if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"):
                raise ValueError(
                    f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'"
                )

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.
        """
        if do_classifier_free_guidance:
            pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
        else:
            pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()

        def is_self_attn(module_name):
            r"""
            Check if the module is self-attention module based on its name.
            """
            return "attn1" in module_name and "to" not in name

        def get_block_type(module_name):
            r"""
            Get the block type from the module name. Can be "down", "mid", "up".
            """
            # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
            # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down"
            return module_name.split(".")[0].split("_")[0]

        def get_block_index(module_name):
            r"""
            Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g.
            mid_block) and the index is omitted from the name, it will be "block_0".
            """
            # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
            # mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
            module_name_splits = module_name.split(".")
            block_index = module_name_splits[1]
            if "attentions" in block_index or "motion_modules" in block_index:
                return "block_0"
            else:
                return f"block_{block_index}"

        def get_attn_index(module_name):
            r"""
            Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0",
            "motion_modules_1", ...
            """
            # down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
            # mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
            # down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
            # mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
            module_name_split = module_name.split(".")
            mid_name = module_name_split[1]
            down_name = module_name_split[2]
            if "attentions" in down_name:
                return f"attentions_{module_name_split[3]}"
            if "attentions" in mid_name:
                return f"attentions_{module_name_split[2]}"
            if "motion_modules" in down_name:
                return f"motion_modules_{module_name_split[3]}"
            if "motion_modules" in mid_name:
                return f"motion_modules_{module_name_split[2]}"

        for pag_layer_input in pag_applied_layers:
            # for each PAG layer input, we find corresponding self-attention layers in the unet model
            target_modules = []

            pag_layer_input_splits = pag_layer_input.split(".")

            if len(pag_layer_input_splits) == 1:
                # when the layer input only contains block_type. e.g. "mid", "down", "up"
                block_type = pag_layer_input_splits[0]
                for name, module in self.unet.named_modules():
                    if is_self_attn(name) and get_block_type(name) == block_type:
                        target_modules.append(module)

            elif len(pag_layer_input_splits) == 2:
                # when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
                block_type = pag_layer_input_splits[0]
                block_index = pag_layer_input_splits[1]
                for name, module in self.unet.named_modules():
                    if (
                        is_self_attn(name)
                        and get_block_type(name) == block_type
                        and get_block_index(name) == block_index
                    ):
                        target_modules.append(module)

            elif len(pag_layer_input_splits) == 3:
                # when the layer input contains block_type, block_index and attention_index.
                # e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1"
                block_type = pag_layer_input_splits[0]
                block_index = pag_layer_input_splits[1]
                attn_index = pag_layer_input_splits[2]

                for name, module in self.unet.named_modules():
                    if (
                        is_self_attn(name)
                        and get_block_type(name) == block_type
                        and get_block_index(name) == block_index
                        and get_attn_index(name) == attn_index
                    ):
                        target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")

            for module in target_modules:
                module.processor = pag_attn_proc

    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.
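
        Illustrative example (hypothetical values): with `pag_scale=3.0` and `pag_adaptive_scale=0.001`, the adaptive
        scale is `3.0 - 0.001 * (1000 - t)`, i.e. 3.0 at t=1000, 2.5 at t=500, and 2.0 at t=0 (clamped at 0 if it
        goes negative).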
        """

        if self.do_pag_adaptive_scaling:
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.

        Returns:
            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
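
        With classifier-free guidance, the result is `uncond + guidance_scale * (text - uncond) + pag_scale * (text -
        perturb)`, where `uncond`, `text`, and `perturb` are the three chunks of `noise_pred`; without it, the result
        is `text + pag_scale * (text - perturb)`.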
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        return noise_pred

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepare the perturbed attention guidance input by duplicating the conditional batch and, when classifier-free
        guidance is enabled, prepending the unconditional batch.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
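
        Note: without classifier-free guidance the returned batch is ordered `[cond, cond]` (original plus perturbed
        branch); with it, `[uncond, cond, cond]`, matching the three-way `chunk` in
        `_apply_perturbed_attention_guidance`.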
        """

        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    def set_pag_applied_layers(self, pag_applied_layers):
        r"""
        Set the self-attention layers to apply PAG on. Raises a `ValueError` if the input is invalid.
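
        Example (illustrative sketch; assumes a PAG-enabled pipeline instance `pipe`):

            pipe.set_pag_applied_layers("mid")
            pipe.set_pag_applied_layers(["down.block_1", "down.block_1.attentions_1"])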
        """

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]

        for pag_layer in pag_applied_layers:
            self._check_input_pag_applied_layer(pag_layer)

        self.pag_applied_layers = pag_applied_layers

    @property
    def pag_scale(self):
        """
        Get the scale factor for the perturbed attention guidance.
        """
        return self._pag_scale

    @property
    def pag_adaptive_scale(self):
        """
        Get the adaptive scale factor for the perturbed attention guidance.
        """
        return self._pag_adaptive_scale

    @property
    def do_pag_adaptive_scaling(self):
        """
        Check if the adaptive scaling is enabled for the perturbed attention guidance.
        """
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def do_perturbed_attention_guidance(self):
        """
        Check if the perturbed attention guidance is enabled.
        """
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    def pag_attn_processors(self):
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the
            model, keyed by the name of the layer.
        """

        processors = {}
        for name, proc in self.unet.attn_processors.items():
            if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
                processors[name] = proc
        return processors


class PixArtPAGMixin:
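    r"""Mixin class for Perturbed Attention Guidance (PAG) for PixArt-style transformer pipelines."""
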
    @staticmethod
    def _check_input_pag_applied_layer(layer):
        r"""
        Check if each layer input in `applied_pag_layers` is valid. It should be the block index: {block_index}.
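
        Illustrative valid inputs: an integer such as `3`, or a digit string such as `"3"`.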
        """

        # Check if the layer index is valid (should be int or str of int)
        if isinstance(layer, int):
            return  # Valid layer index

        if isinstance(layer, str):
            if layer.isdigit():
                return  # Valid layer index

        # If it is not a valid layer index, raise a ValueError
        raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.")

    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.
        """
        if do_classifier_free_guidance:
            pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
        else:
            pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()

        def is_self_attn(module_name):
            r"""
            Check if the module is self-attention module based on its name.
            """
            # include "transformer_blocks.1.attn1"; exclude "transformer_blocks.18.attn1.to_q",
            # "transformer_blocks.1.attn1.add_q_proj", ...
            return "attn1" in module_name and len(module_name.split(".")) == 3

        def get_block_index(module_name):
            r"""
            Get the transformer block index from the module name, as a number string like "0", "1", ..., "23".
            """
            # transformer_blocks.23.attn1 -> "23"
            return module_name.split(".")[1]

        for pag_layer_input in pag_applied_layers:
            # for each PAG layer input, we find corresponding self-attention layers in the transformer model
            target_modules = []

            block_index = str(pag_layer_input)

            for name, module in self.transformer.named_modules():
                if is_self_attn(name) and get_block_index(name) == block_index:
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")

            for module in target_modules:
                module.processor = pag_attn_proc

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers
    def set_pag_applied_layers(self, pag_applied_layers):
        r"""
        Set the self-attention layers to apply PAG on. Raises a `ValueError` if the input is invalid.
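
        Example (illustrative sketch; assumes a PAG-enabled PixArt pipeline instance `pipe`):

            pipe.set_pag_applied_layers("3")
            pipe.set_pag_applied_layers([3, 23])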
        """

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]

        for pag_layer in pag_applied_layers:
            self._check_input_pag_applied_layer(pag_layer)

        self.pag_applied_layers = pag_applied_layers

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale
    def _get_pag_scale(self, t):
        r"""
        Get the scale factor for the perturbed attention guidance at timestep `t`.
        """

        if self.do_pag_adaptive_scaling:
            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
            if signal_scale < 0:
                signal_scale = 0
            return signal_scale
        else:
            return self.pag_scale

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance
    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
        r"""
        Apply perturbed attention guidance to the noise prediction.

        Args:
            noise_pred (torch.Tensor): The noise prediction tensor.
            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
            guidance_scale (float): The scale factor for the guidance term.
            t (int): The current time step.

        Returns:
            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
        """
        pag_scale = self._get_pag_scale(t)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
        return noise_pred

    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance
    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
        """
        Prepare the perturbed attention guidance input by duplicating the conditional batch and, when classifier-free
        guidance is enabled, prepending the unconditional batch.

        Args:
            cond (torch.Tensor): The conditional input tensor.
            uncond (torch.Tensor): The unconditional input tensor.
            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.

        Returns:
            torch.Tensor: The prepared perturbed attention guidance tensor.
        """

        cond = torch.cat([cond] * 2, dim=0)

        if do_classifier_free_guidance:
            cond = torch.cat([uncond, cond], dim=0)
        return cond

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale
    def pag_scale(self):
        """
        Get the scale factor for the perturbed attention guidance.
        """
        return self._pag_scale

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale
    def pag_adaptive_scale(self):
        """
        Get the adaptive scale factor for the perturbed attention guidance.
        """
        return self._pag_adaptive_scale

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling
    def do_pag_adaptive_scaling(self):
        """
        Check if the adaptive scaling is enabled for the perturbed attention guidance.
        """
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance
    def do_perturbed_attention_guidance(self):
        """
        Check if the perturbed attention guidance is enabled.
        """
        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0

    @property
    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer
    def pag_attn_processors(self):
        r"""
        Returns:
            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the
            model, keyed by the name of the layer.
        """

        processors = {}
        for name, proc in self.transformer.attn_processors.items():
            if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
                processors[name] = proc
        return processors