# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from dataclasses import dataclass
15
from typing import Any, Dict, List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
16
17
18

import torch
import torch.nn as nn
19
import torch.utils.checkpoint
Patrick von Platen's avatar
Patrick von Platen committed
20
21

from ..configuration_utils import ConfigMixin, register_to_config
22
from ..loaders import UNet2DConditionLoadersMixin
23
from ..utils import BaseOutput, logging
24
from .activations import get_activation
25
from .attention_processor import AttentionProcessor, AttnProcessor
YiYi Xu's avatar
YiYi Xu committed
26
27
from .embeddings import (
    GaussianFourierProjection,
YiYi Xu's avatar
YiYi Xu committed
28
29
30
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
YiYi Xu's avatar
YiYi Xu committed
31
32
33
34
35
36
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
37
from .modeling_utils import ModelMixin
38
from .unet_2d_blocks import (
39
40
41
42
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
Will Berman's avatar
Will Berman committed
43
    UNetMidBlock2DSimpleCrossAttn,
44
45
46
47
    UpBlock2D,
    get_down_block,
    get_up_block,
)
Patrick von Platen's avatar
Patrick von Platen committed
48
49


50
51
52
# Module-level logger, created through the package's own `logging` utility and keyed by this module's name.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


53
54
55
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Defaults to `None` so the output object can be created before the tensor is assigned.
    sample: torch.FloatTensor = None
64
65


66
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Kashif Rasul's avatar
Kashif Rasul committed
67
    r"""
Steven Liu's avatar
Steven Liu committed
68
69
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.
Kashif Rasul's avatar
Kashif Rasul committed
70

Steven Liu's avatar
Steven Liu committed
71
72
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
Kashif Rasul's avatar
Kashif Rasul committed
73
74

    Parameters:
75
76
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
Steven Liu's avatar
Steven Liu committed
77
78
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
Kashif Rasul's avatar
Kashif Rasul committed
79
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Suraj Patil's avatar
Suraj Patil committed
80
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Kashif Rasul's avatar
Kashif Rasul committed
81
82
83
84
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
Will Berman's avatar
Will Berman committed
85
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Steven Liu's avatar
Steven Liu committed
86
87
88
            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
Kashif Rasul's avatar
Kashif Rasul committed
89
            The tuple of upsample blocks to use.
90
91
92
        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
Kashif Rasul's avatar
Kashif Rasul committed
93
94
95
96
97
98
99
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
Steven Liu's avatar
Steven Liu committed
100
            If `None`, normalization and activation layers are skipped in post-processing.
Kashif Rasul's avatar
Kashif Rasul committed
101
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
102
103
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
104
105
106
107
108
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
YiYi Xu's avatar
YiYi Xu committed
109
110
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
Steven Liu's avatar
Steven Liu committed
111
112
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
YiYi Xu's avatar
YiYi Xu committed
113
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
Kashif Rasul's avatar
Kashif Rasul committed
114
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
115
116
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`
Will Berman's avatar
Will Berman committed
117
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
Steven Liu's avatar
Steven Liu committed
118
119
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
120
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
121
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
Steven Liu's avatar
Steven Liu committed
122
        addition_embed_type (`str`, *optional*, defaults to `None`):
Patrick von Platen's avatar
Patrick von Platen committed
123
124
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
125
126
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
Steven Liu's avatar
Steven Liu committed
127
        num_class_embeds (`int`, *optional*, defaults to `None`):
128
129
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
Steven Liu's avatar
Steven Liu committed
130
        time_embedding_type (`str`, *optional*, defaults to `positional`):
131
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
Steven Liu's avatar
Steven Liu committed
132
        time_embedding_dim (`int`, *optional*, defaults to `None`):
Patrick von Platen's avatar
Patrick von Platen committed
133
            An optional override for the dimension of the projected time embedding.
Steven Liu's avatar
Steven Liu committed
134
135
136
137
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
138
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
Steven Liu's avatar
Steven Liu committed
139
140
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
141
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
Will Berman's avatar
Will Berman committed
142
143
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
Steven Liu's avatar
Steven Liu committed
144
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
145
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
146
147
148
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
Steven Liu's avatar
Steven Liu committed
149
150
151
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
            otherwise.
Kashif Rasul's avatar
Kashif Rasul committed
152
153
    """

154
155
    _supports_gradient_checkpointing = True

Patrick von Platen's avatar
Patrick von Platen committed
156
157
158
    @register_to_config
    def __init__(
        self,
Sid Sahai's avatar
Sid Sahai committed
159
160
161
162
163
164
165
166
167
168
169
170
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
171
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
Sid Sahai's avatar
Sid Sahai committed
172
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
173
        only_cross_attention: Union[bool, Tuple[bool]] = False,
Sid Sahai's avatar
Sid Sahai committed
174
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
175
        layers_per_block: Union[int, Tuple[int]] = 2,
Sid Sahai's avatar
Sid Sahai committed
176
177
178
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
179
        norm_num_groups: Optional[int] = 32,
Sid Sahai's avatar
Sid Sahai committed
180
        norm_eps: float = 1e-5,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
181
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
182
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
William Berman's avatar
William Berman committed
183
        encoder_hid_dim: Optional[int] = None,
YiYi Xu's avatar
YiYi Xu committed
184
        encoder_hid_dim_type: Optional[str] = None,
Suraj Patil's avatar
Suraj Patil committed
185
        attention_head_dim: Union[int, Tuple[int]] = 8,
186
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
187
        dual_cross_attention: bool = False,
Suraj Patil's avatar
Suraj Patil committed
188
        use_linear_projection: bool = False,
Will Berman's avatar
Will Berman committed
189
        class_embed_type: Optional[str] = None,
Patrick von Platen's avatar
Patrick von Platen committed
190
        addition_embed_type: Optional[str] = None,
191
        addition_time_embed_dim: Optional[int] = None,
192
        num_class_embeds: Optional[int] = None,
193
        upcast_attention: bool = False,
Will Berman's avatar
Will Berman committed
194
        resnet_time_scale_shift: str = "default",
195
196
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
197
        time_embedding_type: str = "positional",
Patrick von Platen's avatar
Patrick von Platen committed
198
        time_embedding_dim: Optional[int] = None,
199
        time_embedding_act_fn: Optional[str] = None,
200
201
202
203
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
Will Berman's avatar
Will Berman committed
204
        projection_class_embeddings_input_dim: Optional[int] = None,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
205
        class_embeddings_concat: bool = False,
206
        mid_block_only_cross_attention: Optional[bool] = None,
207
        cross_attention_norm: Optional[str] = None,
Patrick von Platen's avatar
Patrick von Platen committed
208
        addition_embed_type_num_heads=64,
Patrick von Platen's avatar
Patrick von Platen committed
209
210
211
212
213
    ):
        super().__init__()

        self.sample_size = sample_size

214
215
216
217
218
219
220
221
        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

Will Berman's avatar
Will Berman committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

238
239
240
241
242
        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

Will Berman's avatar
Will Berman committed
243
244
245
246
247
        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

Sanchit Gandhi's avatar
Sanchit Gandhi committed
248
249
250
251
252
        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

253
254
255
256
257
        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

Patrick von Platen's avatar
Patrick von Platen committed
258
        # input
259
260
261
262
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )
Patrick von Platen's avatar
Patrick von Platen committed
263
264

        # time
265
        if time_embedding_type == "fourier":
Patrick von Platen's avatar
Patrick von Platen committed
266
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
267
268
269
270
271
272
273
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
Patrick von Platen's avatar
Patrick von Platen committed
274
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
275
276
277
278
279

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
Alexander Pivovarov's avatar
Alexander Pivovarov committed
280
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
281
            )
Patrick von Platen's avatar
Patrick von Platen committed
282

283
284
285
286
287
288
289
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )
Patrick von Platen's avatar
Patrick von Platen committed
290

YiYi Xu's avatar
YiYi Xu committed
291
292
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
293
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
YiYi Xu's avatar
YiYi Xu committed
294
295
296
297
298
299
300
301
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
William Berman's avatar
William Berman committed
302
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
YiYi Xu's avatar
YiYi Xu committed
303
304
305
306
307
308
309
310
311
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
YiYi Xu's avatar
YiYi Xu committed
312
313
314
315
316
317
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
YiYi Xu's avatar
YiYi Xu committed
318
319
320
321
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
William Berman's avatar
William Berman committed
322
323
324
        else:
            self.encoder_hid_proj = None

325
        # class embedding
Will Berman's avatar
Will Berman committed
326
        if class_embed_type is None and num_class_embeds is not None:
327
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
Will Berman's avatar
Will Berman committed
328
        elif class_embed_type == "timestep":
329
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
Will Berman's avatar
Will Berman committed
330
331
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
Will Berman's avatar
Will Berman committed
332
333
334
335
336
337
338
339
340
341
342
343
344
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
Sanchit Gandhi's avatar
Sanchit Gandhi committed
345
346
347
348
349
350
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
Will Berman's avatar
Will Berman committed
351
352
        else:
            self.class_embedding = None
353

Patrick von Platen's avatar
Patrick von Platen committed
354
355
356
357
358
359
360
361
362
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
YiYi Xu's avatar
YiYi Xu committed
363
364
365
366
367
368
369
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
370
371
372
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
YiYi Xu's avatar
YiYi Xu committed
373
374
375
376
377
378
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
Patrick von Platen's avatar
Patrick von Platen committed
379
        elif addition_embed_type is not None:
YiYi Xu's avatar
YiYi Xu committed
380
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
Patrick von Platen's avatar
Patrick von Platen committed
381

382
383
384
        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
385
            self.time_embed_act = get_activation(time_embedding_act_fn)
386

Patrick von Platen's avatar
Patrick von Platen committed
387
388
389
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

390
        if isinstance(only_cross_attention, bool):
391
392
393
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

394
395
            only_cross_attention = [only_cross_attention] * len(down_block_types)

396
397
398
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

399
400
401
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

Suraj Patil's avatar
Suraj Patil committed
402
403
404
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

Sanchit Gandhi's avatar
Sanchit Gandhi committed
405
406
407
        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

408
409
410
        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

411
412
413
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

Sanchit Gandhi's avatar
Sanchit Gandhi committed
414
415
416
417
418
419
420
421
        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

Patrick von Platen's avatar
Patrick von Platen committed
422
423
424
425
426
427
428
429
430
        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
431
                num_layers=layers_per_block[i],
432
                transformer_layers_per_block=transformer_layers_per_block[i],
Patrick von Platen's avatar
Patrick von Platen committed
433
434
                in_channels=input_channel,
                out_channels=output_channel,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
435
                temb_channels=blocks_time_embed_dim,
Patrick von Platen's avatar
Patrick von Platen committed
436
437
438
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
439
                resnet_groups=norm_num_groups,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
440
                cross_attention_dim=cross_attention_dim[i],
441
                num_attention_heads=num_attention_heads[i],
Patrick von Platen's avatar
Patrick von Platen committed
442
                downsample_padding=downsample_padding,
443
                dual_cross_attention=dual_cross_attention,
Suraj Patil's avatar
Suraj Patil committed
444
                use_linear_projection=use_linear_projection,
445
                only_cross_attention=only_cross_attention[i],
446
                upcast_attention=upcast_attention,
Will Berman's avatar
Will Berman committed
447
                resnet_time_scale_shift=resnet_time_scale_shift,
448
449
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
450
                cross_attention_norm=cross_attention_norm,
451
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
Patrick von Platen's avatar
Patrick von Platen committed
452
453
454
455
            )
            self.down_blocks.append(down_block)

        # mid
Will Berman's avatar
Will Berman committed
456
457
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
458
                transformer_layers_per_block=transformer_layers_per_block[-1],
Will Berman's avatar
Will Berman committed
459
                in_channels=block_out_channels[-1],
Sanchit Gandhi's avatar
Sanchit Gandhi committed
460
                temb_channels=blocks_time_embed_dim,
Will Berman's avatar
Will Berman committed
461
462
463
464
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
465
                cross_attention_dim=cross_attention_dim[-1],
466
                num_attention_heads=num_attention_heads[-1],
Will Berman's avatar
Will Berman committed
467
468
469
470
471
472
473
474
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
Sanchit Gandhi's avatar
Sanchit Gandhi committed
475
                temb_channels=blocks_time_embed_dim,
Will Berman's avatar
Will Berman committed
476
477
478
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
479
                cross_attention_dim=cross_attention_dim[-1],
480
                attention_head_dim=attention_head_dim[-1],
Will Berman's avatar
Will Berman committed
481
482
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
483
                skip_time_act=resnet_skip_time_act,
484
                only_cross_attention=mid_block_only_cross_attention,
485
                cross_attention_norm=cross_attention_norm,
Will Berman's avatar
Will Berman committed
486
            )
487
488
        elif mid_block_type is None:
            self.mid_block = None
Will Berman's avatar
Will Berman committed
489
490
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
Patrick von Platen's avatar
Patrick von Platen committed
491

492
493
494
        # count how many layers upsample the images
        self.num_upsamplers = 0

Patrick von Platen's avatar
Patrick von Platen committed
495
496
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
497
        reversed_num_attention_heads = list(reversed(num_attention_heads))
498
        reversed_layers_per_block = list(reversed(layers_per_block))
Sanchit Gandhi's avatar
Sanchit Gandhi committed
499
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
500
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
501
        only_cross_attention = list(reversed(only_cross_attention))
502

Patrick von Platen's avatar
Patrick von Platen committed
503
504
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
505
506
            is_final_block = i == len(block_out_channels) - 1

Patrick von Platen's avatar
Patrick von Platen committed
507
508
509
510
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

511
512
513
514
515
516
            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False
Patrick von Platen's avatar
Patrick von Platen committed
517
518
519

            up_block = get_up_block(
                up_block_type,
520
                num_layers=reversed_layers_per_block[i] + 1,
521
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
Patrick von Platen's avatar
Patrick von Platen committed
522
523
524
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
525
                temb_channels=blocks_time_embed_dim,
526
                add_upsample=add_upsample,
Patrick von Platen's avatar
Patrick von Platen committed
527
528
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
529
                resnet_groups=norm_num_groups,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
530
                cross_attention_dim=reversed_cross_attention_dim[i],
531
                num_attention_heads=reversed_num_attention_heads[i],
532
                dual_cross_attention=dual_cross_attention,
Suraj Patil's avatar
Suraj Patil committed
533
                use_linear_projection=use_linear_projection,
534
                only_cross_attention=only_cross_attention[i],
535
                upcast_attention=upcast_attention,
Will Berman's avatar
Will Berman committed
536
                resnet_time_scale_shift=resnet_time_scale_shift,
537
538
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
539
                cross_attention_norm=cross_attention_norm,
540
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
Patrick von Platen's avatar
Patrick von Platen committed
541
542
543
544
545
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
546
547
548
549
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
550

551
            self.conv_act = get_activation(act_fn)
552

553
554
555
556
557
558
559
560
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )
Patrick von Platen's avatar
Patrick von Platen committed
561

562
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name (the dotted module path followed by `.processor`).
        """
        collected = {}

        def _collect(prefix: str, module: torch.nn.Module) -> None:
            # every module that exposes `set_processor` contributes its current processor
            if hasattr(module, "set_processor"):
                collected[f"{prefix}.processor"] = module.processor

            # descend depth-first so that keys follow the module hierarchy
            for child_name, child in module.named_children():
                _collect(f"{prefix}.{child_name}", child)

        for top_name, top_module in self.named_children():
            _collect(top_name, top_module)

        return collected

Patrick von Platen's avatar
Patrick von Platen committed
586
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The dict is
                only read by this method; it is left unmodified.

        Raises:
            ValueError: If `processor` is a dict whose number of entries differs from the number of attention layers.
            KeyError: If `processor` is a dict that has no entry for one of the attention layers.
        """
        count = len(self.attn_processors)

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # look the entry up instead of popping it so the caller's dict is not emptied
                    module.set_processor(processor[f"{name}.processor"])

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
619

620
621
622
623
624
625
    def set_default_attn_processor(self):
        """
        Replaces any custom attention processors with a fresh `AttnProcessor`, i.e. the default attention
        implementation, on every attention layer.
        """
        default_processor = AttnProcessor()
        self.set_attn_processor(default_processor)

626
    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`. A list must contain one entry per sliceable layer; it is not
                modified.

        Raises:
            ValueError: If a list is passed whose length differs from the number of sliceable layers, or if any
                slice size is larger than the sliceable head dim of its layer.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        # a single value applies to every sliceable layer
        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # work on a reversed copy: `pop()` then hands out the sizes in traversal order
        # without touching a list passed in by the caller
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)
690

691
692
693
694
    def _set_gradient_checkpointing(self, module, value=False):
        # only the down/up block types listed here get a `gradient_checkpointing` flag
        checkpointable_blocks = (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
        if not isinstance(module, checkpointable_blocks):
            return
        module.gradient_checkpointing = value

Patrick von Platen's avatar
Patrick von Platen committed
695
696
697
698
699
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`):
                The timestep to denoise the input at. A Python number or a 0-dim tensor is turned into a
                one-element tensor and expanded over the batch dimension of `sample`.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*):
                Class labels for conditioning. Must be provided when the model has a class embedding. The class
                embedding is added to the timestep embedding, or concatenated with it when
                `class_embeddings_concat` is set in the config.
            timestep_cond (`torch.Tensor`, *optional*):
                Additional conditioning that is passed to the time embedding layer together with the projected
                timestep.
            attention_mask (`torch.Tensor`, *optional*):
                A mask of shape `(batch, key_tokens)` with `1` for tokens to keep and `0` for tokens to discard. It
                is converted into a bias (`0` for keep, `-10000.0` for discard) that is added to attention scores.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
                that are passed along to the UNet blocks. Depending on the `addition_embed_type` and
                `encoder_hid_dim_type` config values, the keys `image_embeds`, `text_embeds`, `time_ids` and `hint`
                are required.
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                Tensors that are added, pairwise and in order, to the residuals collected from the down blocks
                before those residuals are consumed by the up blocks.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                A tensor that is added to the sample after the mid block.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is `None`, or if an entry of
                `added_cond_kwargs` required by the config is missing.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # treat a missing `added_cond_kwargs` as empty so that the checks below raise their descriptive
        # `ValueError` instead of failing with a `TypeError` on `None`
        if added_cond_kwargs is None:
            added_cond_kwargs = {}

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            # regroup the per-id projections so there is one row per entry of `text_embeds`
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            # the embedded hint is concatenated to the sample along the channel dimension
            sample = torch.cat([sample, hint], dim=1)

        if aug_emb is not None:
            emb = emb + aug_emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            # add the extra residuals pairwise, in order, to the collected down-block residuals
            down_block_res_samples = tuple(
                down_block_res_sample + down_block_additional_residual
                for down_block_res_sample, down_block_additional_residual in zip(
                    down_block_res_samples, down_block_additional_residuals
                )
            )

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # each up block consumes one residual per resnet, taken from the end of the stack
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)