unet_2d_condition.py 34.3 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2023 The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from dataclasses import dataclass
15
from typing import Any, Dict, List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
16
17
18

import torch
import torch.nn as nn
19
import torch.utils.checkpoint
Patrick von Platen's avatar
Patrick von Platen committed
20
21

from ..configuration_utils import ConfigMixin, register_to_config
22
from ..loaders import UNet2DConditionLoadersMixin
23
from ..utils import BaseOutput, deprecate, logging
24
from .attention_processor import AttentionProcessor, AttnProcessor
25
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
26
from .modeling_utils import ModelMixin
27
from .unet_2d_blocks import (
28
29
30
31
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
Will Berman's avatar
Will Berman committed
32
    UNetMidBlock2DSimpleCrossAttn,
33
34
35
36
    UpBlock2D,
    get_down_block,
    get_up_block,
)
Patrick von Platen's avatar
Patrick von Platen committed
37
38


39
40
41
# Module-level logger obtained from the package's own `logging` utilities (imported from `..utils`).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


42
43
44
45
46
47
48
49
50
51
52
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Output container holding the denoised sample of the conditional 2D UNet.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Single output field; the class carries no other state or behavior of its own.
    sample: torch.FloatTensor


53
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Kashif Rasul's avatar
Kashif Rasul committed
54
55
56
57
58
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.).
Kashif Rasul's avatar
Kashif Rasul committed
60
61

    Parameters:
62
63
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
Kashif Rasul's avatar
Kashif Rasul committed
64
65
66
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Suraj Patil's avatar
Suraj Patil committed
67
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
Will Berman's avatar
Will Berman committed
72
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
            mid block layer if `None`.
Kashif Rasul's avatar
Kashif Rasul committed
75
76
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
            The tuple of upsample blocks to use.
77
78
79
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
Kashif Rasul's avatar
Kashif Rasul committed
80
81
82
83
84
85
86
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, it will skip the normalization and activation layers in post-processing.
Kashif Rasul's avatar
Kashif Rasul committed
88
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
89
90
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
William Berman's avatar
William Berman committed
91
92
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`.
Kashif Rasul's avatar
Kashif Rasul committed
93
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
Will Berman's avatar
Will Berman committed
94
95
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
96
97
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings (or concatenated
            with them when `class_embeddings_concat` is `True`). Choose from `None`, `"timestep"`, `"identity"`,
            `"projection"`, or `"simple_projection"`.
99
100
101
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
102
103
104
105
106
107
108
        time_embedding_type (`str`, *optional*, default to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, default to `None`):
            The dimension of `cond_proj` layer in timestep embedding.
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
Will Berman's avatar
Will Berman committed
109
110
111
        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
Sanchit Gandhi's avatar
Sanchit Gandhi committed
112
113
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
Kashif Rasul's avatar
Kashif Rasul committed
114
115
    """

116
117
    # Class-level capability flag. NOTE(review): presumably consulted by `ModelMixin`'s
    # gradient-checkpointing helpers (defined elsewhere) — confirm against the base class.
    _supports_gradient_checkpointing = True

Patrick von Platen's avatar
Patrick von Platen committed
118
119
120
    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        encoder_hid_dim: Optional[int] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
    ):
        """
        Build the UNet: input convolution, time/class embeddings, down blocks, mid block, up blocks and output head.

        See the class docstring for the meaning of every argument. NOTE(review): the `register_to_config`
        decorator (defined elsewhere) is expected to record these arguments in `self.config` — confirm there.

        Per-block arguments (`only_cross_attention`, `attention_head_dim`, `cross_attention_dim`,
        `layers_per_block`) may be a scalar, which is broadcast to every down/up block, or a sequence (list or
        tuple) with exactly one entry per down block.

        Raises:
            ValueError: If the numbers of down and up blocks differ, if `block_out_channels` or any per-block
                sequence argument does not have one entry per down block, if `time_embedding_type` or
                `mid_block_type` is unknown, or if a `"projection"` / `"simple_projection"` class embedding is
                requested without `projection_class_embeddings_input_dim`.
        """
        super().__init__()

        self.sample_size = sample_size

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # Test against the scalar type (like the sibling checks) rather than `list`: tuples are accepted too, and a
        # wrong-length tuple would otherwise slip through and later either raise an opaque IndexError (too short)
        # or silently assign the wrong dims to the mid/up blocks via `[-1]` / `reversed(...)` (too long).
        if not isinstance(cross_attention_dim, int) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        # "Same" padding for odd kernel sizes keeps the spatial resolution unchanged.
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # Optional projection of `encoder_hidden_states` to `cross_attention_dim`. This runs before
        # `cross_attention_dim` is broadcast to a per-block tuple below, so it sees the value as passed in.
        if encoder_hid_dim is not None:
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            # `nn.Identity` ignores constructor arguments, so none are passed.
            self.class_embedding = nn.Identity()
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # Registered here (before the mid block) so submodule registration order is unchanged.
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar per-block arguments to one entry per down block.
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up: per-block settings are consumed in reverse, mirroring the down path.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                # One extra layer per up block relative to the matching down block.
                num_layers=reversed_layers_per_block[i] + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            # NOTE(review): always SiLU regardless of `act_fn`; kept as-is so existing checkpoints behave the same.
            self.conv_act = nn.SiLU()
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )
Patrick von Platen's avatar
Patrick von Platen committed
414

415
416
417
418
419
420
421
422
423
424
    @property
    def in_channels(self):
        """
        Deprecated alias that reads `in_channels` from the model config.

        Access is routed through `deprecate` before the config value is returned.
        """
        notice = (
            "Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead"
        )
        deprecate("in_channels", "1.0.0", notice, standard_warn=False)
        return self.config.in_channels

425
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Collect the attention processors currently installed in the model.

        Returns:
            `dict` of attention processors: maps `"<dotted module path>.processor"` to the `processor` attribute of
            every submodule that exposes `set_processor`, in depth-first pre-order of the module tree.
        """
        found: Dict[str, AttentionProcessor] = {}

        # Iterative walk with an explicit stack. Children are pushed in reverse so they are popped in their natural
        # order, reproducing the pre-order of a recursive descent over `named_children()`.
        stack = list(reversed(list(self.named_children())))
        while stack:
            path, module = stack.pop()
            if hasattr(module, "set_processor"):
                found[f"{path}.processor"] = module.processor
            stack.extend(
                (f"{path}.{sub_name}", sub_module)
                for sub_name, sub_module in reversed(list(module.named_children()))
            )

        return found

Patrick von Platen's avatar
Patrick von Platen committed
449
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Install attention processor(s) on every submodule that exposes `set_processor`.

        Parameters:
            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                of **all** `Attention` layers.

                In case `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. The dict passed
                in by the caller is left unmodified.

        Raises:
            ValueError: If `processor` is a dict whose length differs from the number of attention layers.
            KeyError: If `processor` is a dict that has no entry for one of the attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        if isinstance(processor, dict):
            # Entries are popped below as they are assigned; work on a shallow copy so the caller's dict is not
            # emptied as a side effect (it may be reused, e.g. for a second model).
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
478

479
480
481
482
483
484
    def set_default_attn_processor(self):
        """
        Reset every attention layer to a freshly constructed `AttnProcessor`, dropping any custom processors.
        """
        default_processor = AttnProcessor()
        self.set_attn_processor(default_processor)

485
    def set_attention_slice(self, slice_size="auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)` or `tuple(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`. A list or tuple supplies one value per sliceable attention layer,
                in the order the layers are found by a depth-first walk over `self.children()`; `None` entries are
                forwarded without the size check.

        Raises:
            ValueError: If the number of given slice sizes does not match the number of sliceable layers, or if a
                slice size is larger than the `sliceable_head_dim` of its layer.
        """
        sliceable_head_dims = []

        # Collect `sliceable_head_dim` from every descendant exposing `set_attention_slice`, in pre-order.
        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if isinstance(slice_size, (list, tuple)):
            # Explicit per-layer sizes. Copy so the caller's sequence is never touched, and accept tuples,
            # which were previously broadcast as one (invalid) size.
            slice_size = list(slice_size)
        elif slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]
        else:
            # a single value is applied to every sliceable layer
            slice_size = num_sliceable_layers * [slice_size]

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # The walk visits layers in the same order as the retrieval above, so popping from the end of a reversed
        # copy hands each layer its own entry.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)
549

550
551
552
553
    def _set_gradient_checkpointing(self, module, value=False):
        """Set `module.gradient_checkpointing = value` when `module` is one of the plain/cross-attn down/up blocks."""
        checkpointable_blocks = (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
        if not isinstance(module, checkpointable_blocks):
            # any other module type is left untouched
            return
        module.gradient_checkpointing = value

Patrick von Platen's avatar
Patrick von Platen committed
554
555
556
557
558
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Run the conditional UNet on a batch of noisy samples.

        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps. A Python number or a 0-d tensor is
                converted to a 1-element tensor and expanded to the batch size of `sample`.
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states.
                They are passed through `self.encoder_hid_proj` first when that projection is set.
            class_labels (`torch.Tensor`, *optional*):
                Labels for the class embedding. Required when the model has a class embedding. When
                `config.class_embed_type == "timestep"` they are first passed through `self.time_proj`. The resulting
                class embedding is concatenated to (`config.class_embeddings_concat`) or added to the time embedding.
            timestep_cond (`torch.Tensor`, *optional*):
                Extra conditioning passed as the second argument of `self.time_embedding`.
            attention_mask (`torch.Tensor`, *optional*):
                Mask where 1 keeps and 0 masks out a position. It is turned into the additive bias
                `(1 - mask) * -10000.0` and unsqueezed at dim 1 before being passed to the attention blocks.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
                Tensors added element-wise, in order, to the residuals collected on the down path. The first residual
                is the `conv_in` output, followed by those returned by each down block.
            mid_block_additional_residual (`torch.Tensor`, *optional*):
                Tensor added element-wise to the output of the mid block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is `None`.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default, the spatial size of the sample has to be a multiple of the overall upsampling factor,
        # which is 2 ** (number of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask: keep (1) -> 0.0 bias, discard (0) -> -10000.0 bias
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # Same reasoning as for `t_emb` above: the projection always yields f32, while
                # `class_embedding` may run in another precision (e.g. fp16). Without this cast its
                # weights and input would have mismatching dtypes.
                class_labels = class_labels.to(dtype=self.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.encoder_hid_proj is not None:
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # consume this block's skip connections from the end of the down-path residuals
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)