# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
15
from typing import Any, Dict, List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
16
17
18

import torch
import torch.nn as nn
19
import torch.utils.checkpoint
Patrick von Platen's avatar
Patrick von Platen committed
20
21

from ..configuration_utils import ConfigMixin, register_to_config
22
from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
Patrick von Platen's avatar
Patrick von Platen committed
23
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
from .activations import get_activation
25
26
27
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
28
    Attention,
29
30
31
32
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
YiYi Xu's avatar
YiYi Xu committed
33
34
from .embeddings import (
    GaussianFourierProjection,
35
    GLIGENTextBoundingboxProjection,
YiYi Xu's avatar
YiYi Xu committed
36
37
38
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
YiYi Xu's avatar
YiYi Xu committed
39
40
41
42
43
44
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
45
from .modeling_utils import ModelMixin
46
from .unet_2d_blocks import (
47
    get_down_block,
48
    get_mid_block,
49
50
    get_up_block,
)
Patrick von Platen's avatar
Patrick von Platen committed
51
52


53
54
55
# Module-level logger, created through the package's own `logging.get_logger` helper and named after this module.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


56
57
58
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Defaults to `None`, so the dataclass can be constructed without arguments.
    sample: torch.FloatTensor = None
67
68


69
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
Kashif Rasul's avatar
Kashif Rasul committed
70
    r"""
Steven Liu's avatar
Steven Liu committed
71
72
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.
Kashif Rasul's avatar
Kashif Rasul committed
73

Steven Liu's avatar
Steven Liu committed
74
75
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
Kashif Rasul's avatar
Kashif Rasul committed
76
77

    Parameters:
78
79
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
Steven Liu's avatar
Steven Liu committed
80
81
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
Kashif Rasul's avatar
Kashif Rasul committed
82
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Suraj Patil's avatar
Suraj Patil committed
83
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Kashif Rasul's avatar
Kashif Rasul committed
84
85
86
87
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
Will Berman's avatar
Will Berman committed
88
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
Steven Liu's avatar
Steven Liu committed
90
91
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
Kashif Rasul's avatar
Kashif Rasul committed
92
            The tuple of upsample blocks to use.
93
94
95
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
Kashif Rasul's avatar
Kashif Rasul committed
96
97
98
99
100
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
Kashif Rasul's avatar
Kashif Rasul committed
102
103
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
Kashif Rasul's avatar
Kashif Rasul committed
105
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
106
107
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
108
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. If `None`, the reversed
            `transformer_layers_per_block` is used for the upsampling blocks.
117
        encoder_hid_dim (`int`, *optional*, defaults to None):
YiYi Xu's avatar
YiYi Xu committed
118
119
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
Steven Liu's avatar
Steven Liu committed
120
121
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
YiYi Xu's avatar
YiYi Xu committed
122
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
Kashif Rasul's avatar
Kashif Rasul committed
123
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
124
125
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`
Will Berman's avatar
Will Berman committed
126
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
Steven Liu's avatar
Steven Liu committed
127
128
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
129
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
130
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
Steven Liu's avatar
Steven Liu committed
131
        addition_embed_type (`str`, *optional*, defaults to `None`):
Patrick von Platen's avatar
Patrick von Platen committed
132
133
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
134
135
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
Steven Liu's avatar
Steven Liu committed
136
        num_class_embeds (`int`, *optional*, defaults to `None`):
137
138
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
Steven Liu's avatar
Steven Liu committed
139
        time_embedding_type (`str`, *optional*, defaults to `positional`):
140
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
Steven Liu's avatar
Steven Liu committed
141
        time_embedding_dim (`int`, *optional*, defaults to `None`):
Patrick von Platen's avatar
Patrick von Platen committed
142
            An optional override for the dimension of the projected time embedding.
Steven Liu's avatar
Steven Liu committed
143
144
145
146
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
147
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
Steven Liu's avatar
Steven Liu committed
148
149
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
150
151
152
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
154
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
155
156
157
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
Steven Liu's avatar
Steven Liu committed
158
159
160
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
            otherwise.
Kashif Rasul's avatar
Kashif Rasul committed
161
162
    """

163
164
    _supports_gradient_checkpointing = True

Patrick von Platen's avatar
Patrick von Platen committed
165
166
167
    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
    ):
        """
        Build every sub-module of the UNet from the configuration arguments.

        The arguments are documented in the class docstring. Per-block arguments given as a single scalar
        (`only_cross_attention`, `num_attention_heads`, `attention_head_dim`, `cross_attention_dim`,
        `layers_per_block`, `transformer_layers_per_block`) are broadcast to one value per down block.

        Raises:
            ValueError: If `num_attention_heads` is passed explicitly, or if `_check_config` rejects the
                combination of block arguments (e.g. mismatched numbers of down and up blocks).
        """
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        self._check_config(
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
        )

        # input: "same" padding for an odd kernel size
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim, timestep_input_dim = self._set_time_proj(
            time_embedding_type,
            block_out_channels=block_out_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            time_embedding_dim=time_embedding_dim,
        )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # NOTE: the helpers below receive the *raw* `cross_attention_dim` (possibly a scalar); it is only
        # broadcast to a per-block tuple further down, so keep this ordering.
        self._set_encoder_hid_proj(
            encoder_hid_dim_type,
            cross_attention_dim=cross_attention_dim,
            encoder_hid_dim=encoder_hid_dim,
        )

        # class embedding
        self._set_class_embedding(
            class_embed_type,
            act_fn=act_fn,
            num_class_embeds=num_class_embeds,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            time_embed_dim=time_embed_dim,
            timestep_input_dim=timestep_input_dim,
        )

        self._set_add_embedding(
            addition_embed_type,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
            addition_time_embed_dim=addition_time_embed_dim,
            cross_attention_dim=cross_attention_dim,
            encoder_hid_dim=encoder_hid_dim,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            time_embed_dim=time_embed_dim,
        )

        # Optional activation applied once to the time embedding; `None` disables it.
        self.time_embed_act = None if time_embedding_act_fn is None else get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # set or unroll configs: scalars are broadcast to one entry per down block
        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                # every down block except the last one halves the resolution
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid: operates at the deepest level, so it uses the last entry of every per-block setting
        self.mid_block = get_mid_block(
            mid_block_type,
            temb_channels=blocks_time_embed_dim,
            in_channels=block_out_channels[-1],
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            output_scale_factor=mid_block_scale_factor,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            num_attention_heads=num_attention_heads[-1],
            cross_attention_dim=cross_attention_dim[-1],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            resnet_skip_time_act=resnet_skip_time_act,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[-1],
            dropout=dropout,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up: per-block settings are mirrored so index i lines up with the up blocks
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = (
            list(reversed(transformer_layers_per_block))
            if reverse_transformer_layers_per_block is None
            else reverse_transformer_layers_per_block
        )
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # channel count of the next entry in reversed order, clamped to the last entry
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            add_upsample = not is_final_block
            if add_upsample:
                self.num_upsamplers += 1

            up_block = get_up_block(
                up_block_type,
                # one more layer than the mirrored down block
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resolution_idx=i,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                # NOTE(review): unlike the other per-block settings above, `attention_head_dim` is indexed in
                # down-block order here, not reversed. Left unchanged because altering it could change the layers
                # built for existing non-uniform configs/checkpoints — confirm whether this is intentional.
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )

            self.conv_act = get_activation(act_fn)

        else:
            # `norm_num_groups=None` skips the output normalization and activation entirely
            self.conv_norm_out = None
            self.conv_act = None

        # "same" padding for an odd kernel size
        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

        self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)

    def _check_config(
        self,
        down_block_types: Tuple[str],
        up_block_types: Tuple[str],
        only_cross_attention: Union[bool, Tuple[bool]],
        block_out_channels: Tuple[int],
        layers_per_block: [int, Tuple[int]],
        cross_attention_dim: Union[int, Tuple[int]],
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
        reverse_transformer_layers_per_block: bool,
        attention_head_dim: int,
        num_attention_heads: Optional[Union[int, Tuple[int]]],
    ):
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )
        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
            for layer_number_per_block in transformer_layers_per_block:
                if isinstance(layer_number_per_block, list):
                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

    def _set_time_proj(
        self,
        time_embedding_type: str,
        block_out_channels: int,
        flip_sin_to_cos: bool,
        freq_shift: float,
        time_embedding_dim: int,
    ) -> Tuple[int, int]:
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        return time_embed_dim, timestep_input_dim

    def _set_encoder_hid_proj(
        self,
        encoder_hid_dim_type: Optional[str],
        cross_attention_dim: Union[int, Tuple[int]],
        encoder_hid_dim: Optional[int],
    ):
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

    def _set_class_embedding(
        self,
        class_embed_type: Optional[str],
        act_fn: str,
        num_class_embeds: Optional[int],
        projection_class_embeddings_input_dim: Optional[int],
        time_embed_dim: int,
        timestep_input_dim: int,
    ):
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None
637

638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
    def _set_add_embedding(
        self,
        addition_embed_type: str,
        addition_embed_type_num_heads: int,
        addition_time_embed_dim: Optional[int],
        flip_sin_to_cos: bool,
        freq_shift: float,
        cross_attention_dim: Optional[int],
        encoder_hid_dim: Optional[int],
        projection_class_embeddings_input_dim: Optional[int],
        time_embed_dim: int,
    ):
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim
655

656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
Patrick von Platen's avatar
Patrick von Platen committed
677

678
    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
679
        if attention_type in ["gated", "gated-text-image"]:
680
681
682
683
684
            positive_len = 768
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
                positive_len = cross_attention_dim[0]
685
686

            feature_type = "text-only" if attention_type == "gated" else "text-image"
687
            self.position_net = GLIGENTextBoundingboxProjection(
688
689
                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
            )
690

691
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names (`"<module path>.processor"`).
        """
        found = {}

        def _collect(prefix: str, node: torch.nn.Module) -> None:
            # Any submodule that can report a processor contributes one entry, keyed by its dotted module path.
            if hasattr(node, "get_processor"):
                found[f"{prefix}.processor"] = node.get_processor(return_deprecated_lora=True)
            for child_name, child in node.named_children():
                _collect(f"{prefix}.{child_name}", child)

        for top_name, top_module in self.named_children():
            _collect(top_name, top_module)

        return found

715
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The dict passed
                in is not modified.

        Raises:
            ValueError: If `processor` is a dict whose size differs from the number of attention processors in the
                model.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # Entries are popped below as they are assigned. Work on a shallow copy so the caller's dict (e.g. a
            # stored snapshot that is passed back in later) is not emptied as a side effect.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
748

749
750
751
752
    def set_default_attn_processor(self):
        """
        Replace every attention processor with the library default.

        `AttnAddedKVProcessor` is used when all current processors are added-KV processors; `AttnProcessor` is used
        when all are cross-attention processors. Any other mix raises `ValueError`.
        """
        current = self.attn_processors
        current_classes = [proc.__class__ for proc in current.values()]

        if all(cls in ADDED_KV_ATTENTION_PROCESSORS for cls in current_classes):
            default_processor = AttnAddedKVProcessor()
        elif all(cls in CROSS_ATTENTION_PROCESSORS for cls in current_classes):
            default_processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(current.values()))}"
            )

        self.set_attn_processor(default_processor)
763

764
    def set_attention_slice(self, slice_size="auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`. A list (or tuple) gives one value per sliceable layer.

        Raises:
            ValueError: If a per-layer sequence has the wrong length, or a slice size exceeds its layer's
                `sliceable_head_dim`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        # A single value applies to every layer; a tuple is accepted as a per-layer sequence just like a list.
        if isinstance(slice_size, (list, tuple)):
            slice_size = list(slice_size)
        else:
            slice_size = num_sliceable_layers * [slice_size]

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # Both walks visit modules in the same order, so popping from the reversed list hands the i-th
        # retrieved layer its i-th slice size.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)
828

829
    def _set_gradient_checkpointing(self, module, value=False):
830
        if hasattr(module, "gradient_checkpointing"):
831
832
            module.gradient_checkpointing = value

833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
    def enable_freeu(self, s1, s2, b1, b2):
        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stage blocks where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
        """
        # Every up block receives the same four factors as plain attributes.
        for upsample_block in self.up_blocks:
            upsample_block.s1 = s1
            upsample_block.s2 = s2
            upsample_block.b1 = b1
            upsample_block.b2 = b2

    def disable_freeu(self):
        """Disables the FreeU mechanism by resetting any FreeU factor present on an up block to `None`."""
        freeu_keys = ("s1", "s2", "b1", "b2")
        for upsample_block in self.up_blocks:
            for key in freeu_keys:
                # The former `hasattr(...) or getattr(..., None) is not None` test reduces to `hasattr`: a non-None
                # getattr result implies the attribute exists. Blocks that never had the factor are left as-is.
                if hasattr(upsample_block, key):
                    setattr(upsample_block, key, None)

865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

        The processors in use beforehand are kept in `self.original_attn_processors`. Models using added-KV
        processors are rejected with `ValueError`.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        current_processors = self.attn_processors
        if any("Added" in str(proc.__class__.__name__) for proc in current_processors.values()):
            raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = current_processors

        for attention_module in (m for m in self.modules() if isinstance(m, Attention)):
            attention_module.fuse_projections(fuse=True)

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Restores the attention processors saved by `fuse_qkv_projections`. Does nothing if no processors were
        saved, including when `fuse_qkv_projections` was never called.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # `original_attn_processors` is assigned by `fuse_qkv_projections`; default to None so calling this first
        # is a no-op instead of an AttributeError.
        original = getattr(self, "original_attn_processors", None)
        if original is not None:
            # Pass a copy: `set_attn_processor` may consume the dict it receives, and the saved processors must stay
            # intact so this method can be called again.
            self.set_attn_processor(dict(original))

901
902
903
904
905
906
907
908
909
910
911
    def unload_lora(self):
        """Unloads LoRA weights by clearing the LoRA layer of every submodule exposing `set_lora_layer` (deprecated)."""
        deprecate(
            "unload_lora",
            "0.28.0",
            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
        )
        for submodule in self.modules():
            if not hasattr(submodule, "set_lora_layer"):
                continue
            submodule.set_lora_layer(None)

912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
    def get_time_embed(
        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
    ) -> Optional[torch.Tensor]:
        """
        Project `timestep` with `self.time_proj`, broadcast over the batch of `sample`.

        Python numbers are wrapped in a tensor on `sample.device`, using 32-bit dtypes on MPS and 64-bit dtypes
        otherwise. 0-d tensors get a leading dimension and are moved to `sample.device`. The result is expanded to
        `sample.shape[0]` entries before projection and is cast to `sample.dtype`.
        """
        if torch.is_tensor(timestep):
            timesteps = timestep[None].to(sample.device) if timestep.dim() == 0 else timestep
        else:
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            on_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                wanted_dtype = torch.float32 if on_mps else torch.float64
            else:
                wanted_dtype = torch.int32 if on_mps else torch.int64
            timesteps = torch.tensor([timestep], dtype=wanted_dtype, device=sample.device)

        # `expand` (rather than repeat) keeps the broadcast ONNX/Core ML friendly.
        projected = self.time_proj(timesteps.expand(sample.shape[0]))

        # `Timesteps` has no weights and always yields f32, while the time embedding may run in fp16.
        return projected.to(dtype=sample.dtype)

    def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """
        Embed `class_labels` with `self.class_embedding`, cast to `sample.dtype`.

        Returns `None` when the model has no class embedding. For `class_embed_type == "timestep"` the labels are first
        passed through `self.time_proj`.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is `None`.
        """
        if self.class_embedding is None:
            return None

        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if self.config.class_embed_type == "timestep":
            # `Timesteps` has no weights and always returns f32 tensors, so cast back to the sample dtype.
            class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)

        return self.class_embedding(class_labels).to(dtype=sample.dtype)

    def get_aug_embed(
        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
    ) -> Optional[torch.Tensor]:
        """
        Compute the additional embedding selected by `config.addition_embed_type`.

        Args:
            emb: The timestep embedding; only its dtype is used here (to cast the `"text_time"` inputs).
            encoder_hidden_states: Input for `"text"`, and the fallback text embeddings for `"text_image"` when
                `added_cond_kwargs` has no `text_embeds`.
            added_cond_kwargs: Extra conditioning tensors; the required keys depend on the embedding type.

        Returns:
            The output of `self.add_embedding`, or `None` when `addition_embed_type` is not one of the handled types.

        Raises:
            ValueError: If a key required by the configured embedding type is missing from `added_cond_kwargs`.
        """
        embed_type = self.config.addition_embed_type

        if embed_type == "text":
            return self.add_embedding(encoder_hidden_states)

        if embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            return self.add_embedding(text_embs, image_embs)

        if embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            # Project every time id, then regroup the projections per batch row next to the text embeddings.
            time_embeds = self.add_time_proj(time_ids.flatten()).reshape((text_embeds.shape[0], -1))
            combined = torch.concat([text_embeds, time_embeds], dim=-1).to(emb.dtype)
            return self.add_embedding(combined)

        if embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            return self.add_embedding(added_cond_kwargs.get("image_embeds"))

        if embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            return self.add_embedding(added_cond_kwargs.get("image_embeds"), added_cond_kwargs.get("hint"))

        return None

    def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
        """
        Apply `self.encoder_hid_proj` according to `config.encoder_hid_dim_type`.

        - `"text_proj"`: project `encoder_hidden_states`.
        - `"text_image_proj"`: project `encoder_hidden_states` together with `image_embeds` (Kandinsky 2.1 style).
        - `"image_proj"`: replace the hidden states with the projected `image_embeds` (Kandinsky 2.2 style).
        - `"ip_image_proj"`: append the projected `image_embeds`, cast to the hidden-state dtype, along dim 1.

        In every other case (including no projection) `encoder_hidden_states` is returned unchanged.

        Raises:
            ValueError: If an image-based type is configured but `added_cond_kwargs` has no `image_embeds`.
        """
        if self.encoder_hid_proj is None:
            return encoder_hidden_states

        hid_type = self.config.encoder_hid_dim_type

        if hid_type == "text_proj":
            return self.encoder_hid_proj(encoder_hidden_states)

        if hid_type in ("text_image_proj", "image_proj", "ip_image_proj"):
            # Treat missing kwargs like an empty dict so the user gets the explanatory error below
            # instead of a TypeError from `in None`.
            if added_cond_kwargs is None or "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to '{hid_type}' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")

            if hid_type == "text_image_proj":
                # Kandinsky 2.1 - style
                return self.encoder_hid_proj(encoder_hidden_states, image_embeds)
            if hid_type == "image_proj":
                # Kandinsky 2.2 - style
                return self.encoder_hid_proj(image_embeds)

            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
            return torch.cat([encoder_hidden_states, image_embeds], dim=1)

        return encoder_hidden_states

Patrick von Platen's avatar
Patrick von Platen committed
1036
1037
1038
1039
1040
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings are summed with the timestep embeddings, or
                concatenated to them when `config.class_embeddings_concat` is set.
            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
                through the `self.time_embedding` layer to obtain the timestep embeddings.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                A `"scale"` entry, if present, is also used as the LoRA scale. A `"gligen"` entry, if present, is
                replaced (in a copy; the caller's dict is not modified) by the output of `self.position_net`.
            added_cond_kwargs (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks.
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                Additional residuals to be added to UNet long skip connections from down blocks to up blocks, for
                example from ControlNet side model(s). They are only treated as ControlNet residuals when
                `mid_block_additional_residual` is also given; otherwise (and when
                `down_intrablock_additional_residuals` is not given) they are treated as deprecated T2I-Adapter
                intrablock residuals.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                Additional residual to be added to the UNet mid block output, for example from a ControlNet side model.
            down_intrablock_additional_residuals (`tuple` or `list` of `torch.Tensor`, *optional*):
                Additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s).
                The sequence is copied internally and is not modified.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = any(dim % default_overall_up_factor != 0 for dim in sample.shape[-2:])
        upsample_size = None

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        aug_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )
        if self.config.addition_embed_type == "image_hint":
            # for "image_hint", get_aug_embed yields (embedding, hint); the hint is concatenated on dim 1
            aug_emb, hint = aug_emb
            sample = torch.cat([sample, hint], dim=1)

        if aug_emb is not None:
            emb = emb + aug_emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        encoder_hidden_states = self.process_encoder_hidden_states(
            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        # 2. pre-process
        sample = self.conv_in(sample)

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            # work on a copy so the caller's dict is not mutated
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
        is_adapter = down_intrablock_additional_residuals is not None
        # maintain backward compatibility for legacy usage, where
        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg
        #       but can only use one or the other
        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
            deprecate(
                "T2I should not use down_block_additional_residuals",
                "1.3.0",
                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
                "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
                "for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
                standard_warn=False,
            )
            down_intrablock_additional_residuals = down_block_additional_residuals
            is_adapter = True

        if is_adapter:
            # Adapter residuals are consumed front-to-back with `pop(0)` below. Use a private list so tuples
            # (as the type hint allows) work and the caller's sequence is left untouched.
            down_intrablock_additional_residuals = list(down_intrablock_additional_residuals)

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)

        try:
            down_block_res_samples = (sample,)
            for downsample_block in self.down_blocks:
                if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                    # For t2i-adapter CrossAttnDownBlock2D
                    additional_residuals = {}
                    if is_adapter and len(down_intrablock_additional_residuals) > 0:
                        additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

                    sample, res_samples = downsample_block(
                        hidden_states=sample,
                        temb=emb,
                        encoder_hidden_states=encoder_hidden_states,
                        attention_mask=attention_mask,
                        cross_attention_kwargs=cross_attention_kwargs,
                        encoder_attention_mask=encoder_attention_mask,
                        **additional_residuals,
                    )
                else:
                    sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
                    if is_adapter and len(down_intrablock_additional_residuals) > 0:
                        # NOTE(review): in-place add kept deliberately; if the block returned `sample` itself
                        # inside `res_samples`, the skip connection also sees this residual.
                        sample += down_intrablock_additional_residuals.pop(0)

                down_block_res_samples += res_samples

            if is_controlnet:
                down_block_res_samples = tuple(
                    res_sample + additional_residual
                    for res_sample, additional_residual in zip(down_block_res_samples, down_block_additional_residuals)
                )

            # 4. mid
            if self.mid_block is not None:
                if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                    sample = self.mid_block(
                        sample,
                        emb,
                        encoder_hidden_states=encoder_hidden_states,
                        attention_mask=attention_mask,
                        cross_attention_kwargs=cross_attention_kwargs,
                        encoder_attention_mask=encoder_attention_mask,
                    )
                else:
                    sample = self.mid_block(sample, emb)

                # To support T2I-Adapter-XL: a leftover adapter residual is applied here only if its shape matches
                if (
                    is_adapter
                    and len(down_intrablock_additional_residuals) > 0
                    and sample.shape == down_intrablock_additional_residuals[0].shape
                ):
                    sample += down_intrablock_additional_residuals.pop(0)

            if is_controlnet:
                sample = sample + mid_block_additional_residual

            # 5. up
            for i, upsample_block in enumerate(self.up_blocks):
                is_final_block = i == len(self.up_blocks) - 1

                # each up block consumes as many skip tensors (from the end) as it has resnets
                num_resnets = len(upsample_block.resnets)
                res_samples = down_block_res_samples[-num_resnets:]
                down_block_res_samples = down_block_res_samples[:-num_resnets]

                # if we have not reached the final block and need to forward the
                # upsample size, we do it here
                if not is_final_block and forward_upsample_size:
                    upsample_size = down_block_res_samples[-1].shape[2:]

                if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                    sample = upsample_block(
                        hidden_states=sample,
                        temb=emb,
                        res_hidden_states_tuple=res_samples,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        upsample_size=upsample_size,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                    )
                else:
                    sample = upsample_block(
                        hidden_states=sample,
                        temb=emb,
                        res_hidden_states_tuple=res_samples,
                        upsample_size=upsample_size,
                        scale=lora_scale,
                    )

            # 6. post-process
            if self.conv_norm_out is not None:
                sample = self.conv_norm_out(sample)
                sample = self.conv_act(sample)
            sample = self.conv_out(sample)
        finally:
            if USE_PEFT_BACKEND:
                # remove `lora_scale` from each PEFT layer, even if a block raised; otherwise the scale
                # would stay applied and compound on the next call
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)