"vscode:/vscode.git/clone" did not exist on "e9ae678699f20eac30ad60be539838ceb2ac248b"
unet_2d_condition.py 50.1 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from dataclasses import dataclass
15
from typing import Any, Dict, List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
16
17
18

import torch
import torch.nn as nn
19
import torch.utils.checkpoint
Patrick von Platen's avatar
Patrick von Platen committed
20
21

from ..configuration_utils import ConfigMixin, register_to_config
22
from ..loaders import UNet2DConditionLoadersMixin
23
from ..utils import BaseOutput, logging
24
from .activations import get_activation
25
from .attention_processor import AttentionProcessor, AttnProcessor
YiYi Xu's avatar
YiYi Xu committed
26
27
from .embeddings import (
    GaussianFourierProjection,
YiYi Xu's avatar
YiYi Xu committed
28
29
30
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
YiYi Xu's avatar
YiYi Xu committed
31
32
33
34
35
36
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
37
from .modeling_utils import ModelMixin
38
from .unet_2d_blocks import (
39
    UNetMidBlock2DCrossAttn,
Will Berman's avatar
Will Berman committed
40
    UNetMidBlock2DSimpleCrossAttn,
41
42
43
    get_down_block,
    get_up_block,
)
Patrick von Platen's avatar
Patrick von Platen committed
44
45


46
47
48
# Module-level logger from the package's `..utils.logging` helpers, keyed by this module's `__name__`.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


49
50
51
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on the `encoder_hidden_states` input. Output of the last layer of
            the model.
    """

    # The single output field. It defaults to `None` even though the annotation is not `Optional`.
    sample: torch.FloatTensor = None
60
61


62
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Kashif Rasul's avatar
Kashif Rasul committed
63
    r"""
Steven Liu's avatar
Steven Liu committed
64
65
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.
Kashif Rasul's avatar
Kashif Rasul committed
66

Steven Liu's avatar
Steven Liu committed
67
68
    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).
Kashif Rasul's avatar
Kashif Rasul committed
69
70

    Parameters:
71
72
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
Steven Liu's avatar
Steven Liu committed
73
74
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
Kashif Rasul's avatar
Kashif Rasul committed
75
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Suraj Patil's avatar
Suraj Patil committed
76
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Kashif Rasul's avatar
Kashif Rasul committed
77
78
79
80
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
Will Berman's avatar
Will Berman committed
81
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Steven Liu's avatar
Steven Liu committed
82
83
84
            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
Kashif Rasul's avatar
Kashif Rasul committed
85
            The tuple of upsample blocks to use.
86
87
88
        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
Kashif Rasul's avatar
Kashif Rasul committed
89
90
91
92
93
94
95
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
Steven Liu's avatar
Steven Liu committed
96
            If `None`, normalization and activation layers is skipped in post-processing.
Kashif Rasul's avatar
Kashif Rasul committed
97
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
98
99
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
100
101
102
103
104
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        encoder_hid_dim (`int`, *optional*, defaults to None):
YiYi Xu's avatar
YiYi Xu committed
105
106
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
Steven Liu's avatar
Steven Liu committed
107
108
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
YiYi Xu's avatar
YiYi Xu committed
109
            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
Kashif Rasul's avatar
Kashif Rasul committed
110
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
111
112
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`
Will Berman's avatar
Will Berman committed
113
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
Steven Liu's avatar
Steven Liu committed
114
115
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
116
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
117
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
Steven Liu's avatar
Steven Liu committed
118
        addition_embed_type (`str`, *optional*, defaults to `None`):
Patrick von Platen's avatar
Patrick von Platen committed
119
120
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
            "text". "text" will use the `TextTimeEmbedding` layer.
121
122
        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
Steven Liu's avatar
Steven Liu committed
123
        num_class_embeds (`int`, *optional*, defaults to `None`):
124
125
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
Steven Liu's avatar
Steven Liu committed
126
        time_embedding_type (`str`, *optional*, defaults to `positional`):
127
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
Steven Liu's avatar
Steven Liu committed
128
        time_embedding_dim (`int`, *optional*, defaults to `None`):
Patrick von Platen's avatar
Patrick von Platen committed
129
            An optional override for the dimension of the projected time embedding.
Steven Liu's avatar
Steven Liu committed
130
131
132
133
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest of
            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
134
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
Steven Liu's avatar
Steven Liu committed
135
136
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
137
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
Will Berman's avatar
Will Berman committed
138
139
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
Steven Liu's avatar
Steven Liu committed
140
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
141
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
142
143
144
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
Steven Liu's avatar
Steven Liu committed
145
146
147
            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
            otherwise.
Kashif Rasul's avatar
Kashif Rasul committed
148
149
    """

150
151
    _supports_gradient_checkpointing = True

Patrick von Platen's avatar
Patrick von Platen committed
152
153
154
    @register_to_config
    def __init__(
        self,
Sid Sahai's avatar
Sid Sahai committed
155
156
157
158
159
160
161
162
163
164
165
166
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
167
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
Sid Sahai's avatar
Sid Sahai committed
168
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
169
        only_cross_attention: Union[bool, Tuple[bool]] = False,
Sid Sahai's avatar
Sid Sahai committed
170
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
171
        layers_per_block: Union[int, Tuple[int]] = 2,
Sid Sahai's avatar
Sid Sahai committed
172
173
174
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
175
        norm_num_groups: Optional[int] = 32,
Sid Sahai's avatar
Sid Sahai committed
176
        norm_eps: float = 1e-5,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
177
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
178
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
William Berman's avatar
William Berman committed
179
        encoder_hid_dim: Optional[int] = None,
YiYi Xu's avatar
YiYi Xu committed
180
        encoder_hid_dim_type: Optional[str] = None,
Suraj Patil's avatar
Suraj Patil committed
181
        attention_head_dim: Union[int, Tuple[int]] = 8,
182
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
183
        dual_cross_attention: bool = False,
Suraj Patil's avatar
Suraj Patil committed
184
        use_linear_projection: bool = False,
Will Berman's avatar
Will Berman committed
185
        class_embed_type: Optional[str] = None,
Patrick von Platen's avatar
Patrick von Platen committed
186
        addition_embed_type: Optional[str] = None,
187
        addition_time_embed_dim: Optional[int] = None,
188
        num_class_embeds: Optional[int] = None,
189
        upcast_attention: bool = False,
Will Berman's avatar
Will Berman committed
190
        resnet_time_scale_shift: str = "default",
191
192
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
193
        time_embedding_type: str = "positional",
Patrick von Platen's avatar
Patrick von Platen committed
194
        time_embedding_dim: Optional[int] = None,
195
        time_embedding_act_fn: Optional[str] = None,
196
197
198
199
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
Will Berman's avatar
Will Berman committed
200
        projection_class_embeddings_input_dim: Optional[int] = None,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
201
        class_embeddings_concat: bool = False,
202
        mid_block_only_cross_attention: Optional[bool] = None,
203
        cross_attention_norm: Optional[str] = None,
Patrick von Platen's avatar
Patrick von Platen committed
204
        addition_embed_type_num_heads=64,
Patrick von Platen's avatar
Patrick von Platen committed
205
206
207
208
209
    ):
        super().__init__()

        self.sample_size = sample_size

210
211
212
213
214
        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

215
216
217
218
219
220
221
222
        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

Will Berman's avatar
Will Berman committed
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

239
240
241
242
243
        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

Will Berman's avatar
Will Berman committed
244
245
246
247
248
        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

Sanchit Gandhi's avatar
Sanchit Gandhi committed
249
250
251
252
253
        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

254
255
256
257
258
        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

Patrick von Platen's avatar
Patrick von Platen committed
259
        # input
260
261
262
263
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )
Patrick von Platen's avatar
Patrick von Platen committed
264
265

        # time
266
        if time_embedding_type == "fourier":
Patrick von Platen's avatar
Patrick von Platen committed
267
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
268
269
270
271
272
273
274
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
Patrick von Platen's avatar
Patrick von Platen committed
275
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
276
277
278
279
280

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
Alexander Pivovarov's avatar
Alexander Pivovarov committed
281
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
282
            )
Patrick von Platen's avatar
Patrick von Platen committed
283

284
285
286
287
288
289
290
        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )
Patrick von Platen's avatar
Patrick von Platen committed
291

YiYi Xu's avatar
YiYi Xu committed
292
293
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
294
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
YiYi Xu's avatar
YiYi Xu committed
295
296
297
298
299
300
301
302
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
William Berman's avatar
William Berman committed
303
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
YiYi Xu's avatar
YiYi Xu committed
304
305
306
307
308
309
310
311
312
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
YiYi Xu's avatar
YiYi Xu committed
313
314
315
316
317
318
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
YiYi Xu's avatar
YiYi Xu committed
319
320
321
322
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
William Berman's avatar
William Berman committed
323
324
325
        else:
            self.encoder_hid_proj = None

326
        # class embedding
Will Berman's avatar
Will Berman committed
327
        if class_embed_type is None and num_class_embeds is not None:
328
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
Will Berman's avatar
Will Berman committed
329
        elif class_embed_type == "timestep":
330
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
Will Berman's avatar
Will Berman committed
331
332
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
Will Berman's avatar
Will Berman committed
333
334
335
336
337
338
339
340
341
342
343
344
345
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
Sanchit Gandhi's avatar
Sanchit Gandhi committed
346
347
348
349
350
351
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
Will Berman's avatar
Will Berman committed
352
353
        else:
            self.class_embedding = None
354

Patrick von Platen's avatar
Patrick von Platen committed
355
356
357
358
359
360
361
362
363
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
YiYi Xu's avatar
YiYi Xu committed
364
365
366
367
368
369
370
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
371
372
373
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
YiYi Xu's avatar
YiYi Xu committed
374
375
376
377
378
379
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
Patrick von Platen's avatar
Patrick von Platen committed
380
        elif addition_embed_type is not None:
YiYi Xu's avatar
YiYi Xu committed
381
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
Patrick von Platen's avatar
Patrick von Platen committed
382

383
384
385
        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
386
            self.time_embed_act = get_activation(time_embedding_act_fn)
387

Patrick von Platen's avatar
Patrick von Platen committed
388
389
390
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

391
        if isinstance(only_cross_attention, bool):
392
393
394
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

395
396
            only_cross_attention = [only_cross_attention] * len(down_block_types)

397
398
399
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

400
401
402
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

Suraj Patil's avatar
Suraj Patil committed
403
404
405
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

Sanchit Gandhi's avatar
Sanchit Gandhi committed
406
407
408
        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

409
410
411
        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

412
413
414
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

Sanchit Gandhi's avatar
Sanchit Gandhi committed
415
416
417
418
419
420
421
422
        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

Patrick von Platen's avatar
Patrick von Platen committed
423
424
425
426
427
428
429
430
431
        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
432
                num_layers=layers_per_block[i],
433
                transformer_layers_per_block=transformer_layers_per_block[i],
Patrick von Platen's avatar
Patrick von Platen committed
434
435
                in_channels=input_channel,
                out_channels=output_channel,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
436
                temb_channels=blocks_time_embed_dim,
Patrick von Platen's avatar
Patrick von Platen committed
437
438
439
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
440
                resnet_groups=norm_num_groups,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
441
                cross_attention_dim=cross_attention_dim[i],
442
                num_attention_heads=num_attention_heads[i],
Patrick von Platen's avatar
Patrick von Platen committed
443
                downsample_padding=downsample_padding,
444
                dual_cross_attention=dual_cross_attention,
Suraj Patil's avatar
Suraj Patil committed
445
                use_linear_projection=use_linear_projection,
446
                only_cross_attention=only_cross_attention[i],
447
                upcast_attention=upcast_attention,
Will Berman's avatar
Will Berman committed
448
                resnet_time_scale_shift=resnet_time_scale_shift,
449
450
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
451
                cross_attention_norm=cross_attention_norm,
452
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
Patrick von Platen's avatar
Patrick von Platen committed
453
454
455
456
            )
            self.down_blocks.append(down_block)

        # mid
Will Berman's avatar
Will Berman committed
457
458
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
459
                transformer_layers_per_block=transformer_layers_per_block[-1],
Will Berman's avatar
Will Berman committed
460
                in_channels=block_out_channels[-1],
Sanchit Gandhi's avatar
Sanchit Gandhi committed
461
                temb_channels=blocks_time_embed_dim,
Will Berman's avatar
Will Berman committed
462
463
464
465
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
466
                cross_attention_dim=cross_attention_dim[-1],
467
                num_attention_heads=num_attention_heads[-1],
Will Berman's avatar
Will Berman committed
468
469
470
471
472
473
474
475
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
Sanchit Gandhi's avatar
Sanchit Gandhi committed
476
                temb_channels=blocks_time_embed_dim,
Will Berman's avatar
Will Berman committed
477
478
479
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
480
                cross_attention_dim=cross_attention_dim[-1],
481
                attention_head_dim=attention_head_dim[-1],
Will Berman's avatar
Will Berman committed
482
483
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
484
                skip_time_act=resnet_skip_time_act,
485
                only_cross_attention=mid_block_only_cross_attention,
486
                cross_attention_norm=cross_attention_norm,
Will Berman's avatar
Will Berman committed
487
            )
488
489
        elif mid_block_type is None:
            self.mid_block = None
Will Berman's avatar
Will Berman committed
490
491
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")
Patrick von Platen's avatar
Patrick von Platen committed
492

493
494
495
        # count how many layers upsample the images
        self.num_upsamplers = 0

Patrick von Platen's avatar
Patrick von Platen committed
496
497
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
498
        reversed_num_attention_heads = list(reversed(num_attention_heads))
499
        reversed_layers_per_block = list(reversed(layers_per_block))
Sanchit Gandhi's avatar
Sanchit Gandhi committed
500
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
501
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
502
        only_cross_attention = list(reversed(only_cross_attention))
503

Patrick von Platen's avatar
Patrick von Platen committed
504
505
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
506
507
            is_final_block = i == len(block_out_channels) - 1

Patrick von Platen's avatar
Patrick von Platen committed
508
509
510
511
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

512
513
514
515
516
517
            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False
Patrick von Platen's avatar
Patrick von Platen committed
518
519
520

            up_block = get_up_block(
                up_block_type,
521
                num_layers=reversed_layers_per_block[i] + 1,
522
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
Patrick von Platen's avatar
Patrick von Platen committed
523
524
525
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
526
                temb_channels=blocks_time_embed_dim,
527
                add_upsample=add_upsample,
Patrick von Platen's avatar
Patrick von Platen committed
528
529
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
530
                resnet_groups=norm_num_groups,
Sanchit Gandhi's avatar
Sanchit Gandhi committed
531
                cross_attention_dim=reversed_cross_attention_dim[i],
532
                num_attention_heads=reversed_num_attention_heads[i],
533
                dual_cross_attention=dual_cross_attention,
Suraj Patil's avatar
Suraj Patil committed
534
                use_linear_projection=use_linear_projection,
535
                only_cross_attention=only_cross_attention[i],
536
                upcast_attention=upcast_attention,
Will Berman's avatar
Will Berman committed
537
                resnet_time_scale_shift=resnet_time_scale_shift,
538
539
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
540
                cross_attention_norm=cross_attention_norm,
541
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
Patrick von Platen's avatar
Patrick von Platen committed
542
543
544
545
546
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
547
548
549
550
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
551

552
            self.conv_act = get_activation(act_fn)
553

554
555
556
557
558
559
560
561
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )
Patrick von Platen's avatar
Patrick von Platen committed
562

563
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: maps `"<dotted submodule path>.processor"` to the `processor` attribute of
            every submodule that exposes a `set_processor` method, in module-traversal order.
        """
        collected = {}

        def _walk(prefix: str, node: torch.nn.Module, sink: Dict[str, AttentionProcessor]):
            # Record this node first, then descend, so keys come out parent-before-children.
            if hasattr(node, "set_processor"):
                sink[f"{prefix}.processor"] = node.processor

            for child_name, child_module in node.named_children():
                _walk(f"{prefix}.{child_name}", child_module, sink)

            return sink

        for top_name, top_module in self.named_children():
            _walk(top_name, top_module, collected)

        return collected

Patrick von Platen's avatar
Patrick von Platen committed
587
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor (`"<dotted submodule path>.processor"`, as returned by `attn_processors`). This is strongly
                recommended when setting trainable attention processors. The caller's dict is not modified.

        Raises:
            ValueError: If a dict is passed whose number of entries differs from the number of attention layers.
            KeyError: If a dict is passed that has no entry for the path of some attention layer.
        """
        count = len(self.attn_processors)

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Index rather than pop: popping emptied the caller's dict as a side effect.
                    module.set_processor(processor[f"{name}.processor"])

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
620

621
622
623
624
625
626
    def set_default_attn_processor(self):
        """
        Replace every attention layer's processor with a fresh default `AttnProcessor`, dropping any custom ones.
        """
        default_processor = AttnProcessor()
        self.set_attn_processor(default_processor)

627
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`. A list gives one entry per sliceable attention layer, in module
                traversal order.

        Raises:
            ValueError: If a list of the wrong length is given, or any requested size exceeds its layer's
                `sliceable_head_dim`.
        """

        def _collect_head_dims(node: torch.nn.Module, found: List[int]):
            # Pre-order walk (node before its children); the assignment pass below uses the same order.
            if hasattr(node, "set_attention_slice"):
                found.append(node.sliceable_head_dim)
            for sub in node.children():
                _collect_head_dims(sub, found)

        head_dims = []
        for top in self.children():
            _collect_head_dims(top, head_dims)
        layer_count = len(head_dims)

        if slice_size == "auto":
            # Halving each head dimension is usually a good speed/memory trade-off.
            slice_size = [d // 2 for d in head_dims]
        elif slice_size == "max":
            # Smallest possible slice for every layer.
            slice_size = layer_count * [1]

        if not isinstance(slice_size, list):
            slice_size = layer_count * [slice_size]

        if len(slice_size) != len(head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(head_dims)}."
            )

        for requested, limit in zip(slice_size, head_dims):
            if requested is not None and requested > limit:
                raise ValueError(f"size {requested} has to be smaller or equal to {limit}.")

        # Hand sizes out in order to every module exposing `set_attention_slice`.
        remaining = iter(slice_size)

        def _apply_sizes(node: torch.nn.Module):
            if hasattr(node, "set_attention_slice"):
                node.set_attention_slice(next(remaining))
            for sub in node.children():
                _apply_sizes(sub)

        for top in self.children():
            _apply_sizes(top)
691

692
    def _set_gradient_checkpointing(self, module, value=False):
693
        if hasattr(module, "gradient_checkpointing"):
694
695
            module.gradient_checkpointing = value

Patrick von Platen's avatar
Patrick von Platen committed
696
697
698
699
700
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        The [`UNet2DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            class_labels (`torch.Tensor`, *optional*):
                Labels for the class embedding. Required when the model has a class embedding. When
                `config.class_embed_type == "timestep"` they are first passed through the timestep projection.
            timestep_cond (`torch.Tensor`, *optional*):
                Extra conditioning passed as the second argument to the time embedding.
            attention_mask (`torch.Tensor`, *optional*):
                Mask with 1 = keep and 0 = discard. It is converted into an additive bias (keep = 0,
                discard = -10000.0) with a singleton dimension inserted at index 1 before being passed to the blocks.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
                are passed along to the UNet blocks. Depending on `config.addition_embed_type` /
                `config.encoder_hid_dim_type`, the keys read are `image_embeds`, `text_embeds`, `time_ids` and `hint`.
            down_block_additional_residuals (`tuple` or `list` of `torch.Tensor`, *optional*):
                When given together with `mid_block_additional_residual` (ControlNet), each tensor is added to the
                matching down-block skip connection. When given alone (T2I-Adapter), tensors are consumed in order,
                one per down block, and added to that block's output. The caller's sequence is not modified.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                Added to the mid block output (ControlNet).
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
                a `tuple` is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `class_labels` is missing while the model has a class embedding, or if a key that the
                configured addition/encoder-projection type requires is missing from `added_cond_kwargs`.
        """
        # Several config-dependent branches below test for keys in `added_cond_kwargs`. With the default `None`
        # those membership tests raised an opaque TypeError instead of the descriptive ValueError.
        if added_cond_kwargs is None:
            added_cond_kwargs = {}

        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        # Adapter residuals are consumed front-to-back below. Use a private list so the caller's sequence is not
        # emptied (it may be reused across calls) and so the annotated tuple type is accepted too.
        adapter_residuals = list(down_block_additional_residuals) if is_adapter else []

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if adapter_residuals:
                    additional_residuals["additional_residuals"] = adapter_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if adapter_residuals:
                    # NOTE(review): deliberately in-place. `res_samples` may hold the same tensor object as
                    # `sample`, in which case the skip connection also receives the residual. Confirm against the
                    # down blocks before switching to an out-of-place add.
                    sample += adapter_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            down_block_res_samples = tuple(
                down_block_res_sample + down_block_additional_residual
                for down_block_res_sample, down_block_additional_residual in zip(
                    down_block_res_samples, down_block_additional_residuals
                )
            )

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # each up block consumes as many skip connections as it has resnets, taken from the end
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)