unet_2d_condition.py 43.5 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from dataclasses import dataclass
15
from typing import Any, Dict, List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
16
17
18

import torch
import torch.nn as nn
19
import torch.nn.functional as F
20
import torch.utils.checkpoint
Patrick von Platen's avatar
Patrick von Platen committed
21
22

from ..configuration_utils import ConfigMixin, register_to_config
23
from ..loaders import UNet2DConditionLoadersMixin
24
from ..utils import BaseOutput, logging
25
from .attention_processor import AttentionProcessor, AttnProcessor
YiYi Xu's avatar
YiYi Xu committed
26
27
28
29
30
31
32
33
from .embeddings import (
    GaussianFourierProjection,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
34
from .modeling_utils import ModelMixin
35
from .unet_2d_blocks import (
36
37
38
39
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
Will Berman's avatar
Will Berman committed
40
    UNetMidBlock2DSimpleCrossAttn,
41
42
43
44
    UpBlock2D,
    get_down_block,
    get_up_block,
)
Patrick von Platen's avatar
Patrick von Platen committed
45
46


47
48
49
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


50
51
52
53
54
55
56
57
58
59
60
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Output container holding the sample tensor produced by a conditional 2D UNet.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # The single field of this output; see the `Args` section above for its expected shape.
    sample: torch.FloatTensor


61
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Kashif Rasul's avatar
Kashif Rasul committed
62
63
64
65
66
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
67
    implements for all the models (such as downloading or saving, etc.)
Kashif Rasul's avatar
Kashif Rasul committed
68
69

    Parameters:
70
71
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
Kashif Rasul's avatar
Kashif Rasul committed
72
73
74
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Suraj Patil's avatar
Suraj Patil committed
75
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Kashif Rasul's avatar
Kashif Rasul committed
76
77
78
79
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
Will Berman's avatar
Will Berman committed
80
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
81
82
            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
            mid block layer if `None`.
Kashif Rasul's avatar
Kashif Rasul committed
83
84
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
            The tuple of upsample blocks to use.
85
86
87
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
Kashif Rasul's avatar
Kashif Rasul committed
88
89
90
91
92
93
94
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
95
            If `None`, it will skip the normalization and activation layers in post-processing
Kashif Rasul's avatar
Kashif Rasul committed
96
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
97
98
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
William Berman's avatar
William Berman committed
99
        encoder_hid_dim (`int`, *optional*, defaults to None):
YiYi Xu's avatar
YiYi Xu committed
100
101
102
103
104
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to None):
            If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
Kashif Rasul's avatar
Kashif Rasul committed
105
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
Will Berman's avatar
Will Berman committed
106
107
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
108
109
        class_embed_type (`str`, *optional*, defaults to None):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
110
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
Patrick von Platen's avatar
Patrick von Platen committed
111
112
113
        addition_embed_type (`str`, *optional*, defaults to None):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None`,
            `"text"` or `"text_image"`. `"text"` will use the `TextTimeEmbedding` layer and `"text_image"` will use
            the `TextImageTimeEmbedding` layer.
114
115
116
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
117
118
        time_embedding_type (`str`, *optional*, default to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
Patrick von Platen's avatar
Patrick von Platen committed
119
120
        time_embedding_dim (`int`, *optional*, default to `None`):
            An optional override for the dimension of the projected time embedding.
121
122
123
        time_embedding_act_fn (`str`, *optional*, default to `None`):
            Optional activation function to apply to the time embeddings once, before they are passed to the rest
            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
124
125
126
127
128
        timestep_post_act (`str`, *optional*, default to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, default to `None`):
            The dimension of `cond_proj` layer in timestep embedding.
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
Will Berman's avatar
Will Berman committed
129
130
131
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
132
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
133
134
135
136
137
138
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
            `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
            default to `False`.
Kashif Rasul's avatar
Kashif Rasul committed
139
140
    """

141
142
    _supports_gradient_checkpointing = True

Patrick von Platen's avatar
Patrick von Platen committed
143
144
145
    @staticmethod
    def _build_activation(name: str) -> nn.Module:
        """Instantiate the activation module registered under `name`.

        `"swish"` and `"silu"` both map to `nn.SiLU`, so every supported activation is a real (picklable,
        parameter-free) `nn.Module` rather than a bare lambda.

        Raises:
            ValueError: If `name` is not one of `"swish"`, `"mish"`, `"silu"` or `"gelu"`.
        """
        activation_classes = {"swish": nn.SiLU, "mish": nn.Mish, "silu": nn.SiLU, "gelu": nn.GELU}
        if name not in activation_classes:
            raise ValueError(f"Unsupported activation function: {name}")
        return activation_classes[name]()

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
    ):
        """
        Validate the configuration and build all sub-modules of the UNet.

        Sub-modules are created in this order: input convolution, time projection/embedding, optional encoder
        hidden-state projection, optional class embedding, optional additional embedding, optional time-embedding
        activation, down blocks, mid block, up blocks and the output norm/activation/convolution. The meaning of each
        argument is described in the class docstring.

        Raises:
            ValueError: If the per-block arguments (`up_block_types`, `block_out_channels`, `only_cross_attention`,
                `attention_head_dim`, `cross_attention_dim`, `layers_per_block`) do not have one entry per
                `down_block_types` entry; if `time_embedding_type`, `encoder_hid_dim_type`, `addition_embed_type`,
                `mid_block_type`, `time_embedding_act_fn` or (when `norm_num_groups` is set) `act_fn` is not a
                supported value; if `encoder_hid_dim_type` is set without `encoder_hid_dim`; or if a projection
                `class_embed_type` is used without `projection_class_embeddings_input_dim`.
        """
        super().__init__()

        self.sample_size = sample_size

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # Validate tuples as well as lists: the annotated type is a tuple, and a tuple of the wrong length would
        # otherwise only surface later as an IndexError (or a silently truncated configuration).
        if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # encoder hidden-state projection
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            # A bare `encoder_hid_dim` implies a plain linear text projection; persist that choice in the config.
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # it is set to `cross_attention_dim` here as this is exactly the required dimension for the currently
            # only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # additional embedding
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
            # too much they are set to `cross_attention_dim` here as this is exactly the required dimension for the
            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

        # optional activation applied to the time embeddings
        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = self._build_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar per-block settings so that each down/up block can index its own value.
        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up: the decoder mirrors the encoder, so every per-block setting is walked in reverse order
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            add_upsample = not is_final_block
            if add_upsample:
                self.num_upsamplers += 1

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = self._build_activation(act_fn)
        else:
            # Without a group count, both the output normalization and its activation are skipped.
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )
Patrick von Platen's avatar
Patrick von Platen committed
523

524
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Collect the attention processors currently attached to the model.

        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name (the dotted path of the owning submodule followed by `.processor`).
        """
        collected = {}

        def _collect(prefix: str, current: torch.nn.Module) -> None:
            # A submodule is treated as an attention layer when it exposes `set_processor`.
            if hasattr(current, "set_processor"):
                collected[f"{prefix}.processor"] = current.processor

            # Keep descending so nested attention layers are registered under their full path.
            for child_name, child in current.named_children():
                _collect(f"{prefix}.{child_name}", child)

        for top_name, top_module in self.named_children():
            _collect(top_name, top_module)

        return collected

Patrick von Platen's avatar
Patrick von Platen committed
548
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Set the attention processor(s) used by the attention layers of the model.

        Parameters:
            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor of **all** `Attention` layers.

                In case `processor` is a dict, each key needs to define the path to the corresponding cross
                attention processor (`<module path>.processor`). This is strongly recommended when setting
                trainable attention processors. The dict is only read; it is not modified by this call.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
            KeyError: If a dict is passed that lacks the entry for one of the attention layers.
        """
        count = len(self.attn_processors)

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Look the entry up instead of popping it so the caller's dict stays intact.
                    module.set_processor(processor[f"{name}.processor"])

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
577

578
579
580
581
582
583
    def set_default_attn_processor(self):
        """
        Replace every attention processor of the model with a fresh default `AttnProcessor`, discarding any custom
        processors that were set before.
        """
        default_processor = AttnProcessor()
        self.set_attn_processor(default_processor)

584
    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`. If a list is provided, it must hold one entry per sliceable
                attention layer, in the order the layers are visited.

        Raises:
            ValueError: If a list of the wrong length is given, or if a slice size exceeds the sliceable head
                dimension of its layer.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]
        elif not isinstance(slice_size, list):
            # a single value is applied to every sliceable layer
            slice_size = num_sliceable_layers * [slice_size]

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Hand the sizes out in visiting order; the traversal below mirrors the one used to collect
        # `sliceable_head_dims`, so the i-th visited layer receives `slice_size[i]`.
        remaining_sizes = iter(slice_size)

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(next(remaining_sizes))

            for child in module.children():
                fn_recursive_set_attention_slice(child)

        for module in self.children():
            fn_recursive_set_attention_slice(module)
648

649
650
651
652
    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` to `value` on `module` when it is one of the 2D down/up block types."""
        checkpointable_blocks = (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
        if not isinstance(module, checkpointable_blocks):
            # Every other module type is left untouched.
            return
        module.gradient_checkpointing = value

Patrick von Platen's avatar
Patrick von Platen committed
653
654
655
656
657
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Run the conditional UNet: embed the timestep (and optional extra conditions), then pass the sample through
        the down blocks, the mid block and the up blocks, and finally through the output convolution.

        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels (`torch.Tensor`, *optional*):
                Labels fed to the class embedding. Required when the model has a class embedding.
            timestep_cond (`torch.Tensor`, *optional*):
                Additional condition passed to the time embedding together with the projected timesteps.
            attention_mask (`torch.Tensor`, *optional*):
                (batch, key_tokens) mask, 1 = keep, 0 = discard. Mask will be converted into a bias, which adds
                large negative values to attention scores corresponding to "discard" tokens.
            encoder_attention_mask (`torch.Tensor`):
                (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
                discard. Mask will be converted into a bias, which adds large negative values to attention scores
                corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            added_cond_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified includes additional conditions that can be used for additional
                time embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type`
                and `addition_embed_type` for more information.
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                Tensors that are added, one by one, to the residuals collected from the down blocks.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                Tensor that is added to the output of the mid block.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is missing, or if the configuration
                requires `image_embeds` and it is not present in `added_cond_kwargs`.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # `added_cond_kwargs` defaults to None; normalise it so the membership checks below raise the
        # intended ValueError instead of a TypeError.
        if added_cond_kwargs is None:
            added_cond_kwargs = {}

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
            emb = emb + aug_emb
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            # fall back to the encoder hidden states when no dedicated text embeddings are supplied
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)

            aug_emb = self.add_embedding(text_embs, image_embs)
            emb = emb + aug_emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # each up block consumes one residual per resnet, taken from the end of the collected tuple
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        # explicit None check: do not rely on the truthiness of a module
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)