unet_2d_condition.py 27.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from dataclasses import dataclass
15
from typing import Any, Dict, List, Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
16
17
18

import torch
import torch.nn as nn
19
import torch.utils.checkpoint
Patrick von Platen's avatar
Patrick von Platen committed
20
21

from ..configuration_utils import ConfigMixin, register_to_config
22
from ..loaders import UNet2DConditionLoadersMixin
23
from ..utils import BaseOutput, logging
24
from .cross_attention import AttnProcessor
25
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
26
from .modeling_utils import ModelMixin
27
from .unet_2d_blocks import (
28
29
30
31
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
Will Berman's avatar
Will Berman committed
32
    UNetMidBlock2DSimpleCrossAttn,
33
34
35
36
    UpBlock2D,
    get_down_block,
    get_up_block,
)
Patrick von Platen's avatar
Patrick von Platen committed
37
38


39
40
41
# Module-level logger, obtained from the package's `logging` utility and keyed by this module's name.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


42
43
44
45
46
47
48
49
50
51
52
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Output container returned by `UNet2DConditionModel.forward` when `return_dict` is `True`.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Final tensor produced by the model's output convolution.
    sample: torch.FloatTensor


53
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Kashif Rasul's avatar
Kashif Rasul committed
54
55
56
57
58
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
59
    implements for all the models (such as downloading or saving, etc.)
Kashif Rasul's avatar
Kashif Rasul committed
60
61

    Parameters:
62
63
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
Kashif Rasul's avatar
Kashif Rasul committed
64
65
66
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
Suraj Patil's avatar
Suraj Patil committed
67
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Kashif Rasul's avatar
Kashif Rasul committed
68
69
70
71
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
Will Berman's avatar
Will Berman committed
72
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
73
74
            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
            mid block layer if `None`.
Kashif Rasul's avatar
Kashif Rasul committed
75
76
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
            The tuple of upsample blocks to use.
77
78
79
        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
Kashif Rasul's avatar
Kashif Rasul committed
80
81
82
83
84
85
86
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
87
            If `None`, it will skip the normalization and activation layers in post-processing
Kashif Rasul's avatar
Kashif Rasul committed
88
89
90
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
Will Berman's avatar
Will Berman committed
91
92
93
94
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
            summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
95
96
97
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
98
99
100
101
102
103
104
105
        time_embedding_type (`str`, *optional*, default to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        timestep_post_act (`str`, *optional*, default to `None`):
            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, default to `None`):
            The dimension of `cond_proj` layer in timestep embedding.
        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
        conv_out_kernel (`int`, *optional*, default to `3`): the Kernel size of `conv_out` layer.
Kashif Rasul's avatar
Kashif Rasul committed
106
107
    """

108
109
    _supports_gradient_checkpointing = True

Patrick von Platen's avatar
Patrick von Platen committed
110
111
112
    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        time_embedding_type: str = "positional",  # fourier, positional
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
    ):
        """
        Build the input convolution, time/class embeddings, down/mid/up blocks and the output head.

        Parameter meanings are described in the class docstring.

        Raises:
            ValueError: If `down_block_types`, `up_block_types` and `block_out_channels` do not all have the same
                length, if a per-block `only_cross_attention` / `attention_head_dim` sequence does not have one
                entry per down block, if `time_embedding_type` is neither `"fourier"` nor `"positional"`, or if
                `mid_block_type` is not a supported value.
        """
        super().__init__()

        self.sample_size = sample_size

        # Validate per-block configuration up front: the loops below index these sequences with a shared
        # block index, so a length mismatch would otherwise fail with an opaque IndexError or produce a
        # model whose channel counts do not line up.
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                "Must provide the same number of `down_block_types` as `up_block_types`."
                f" `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `block_out_channels` as `down_block_types`."
                f" `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `only_cross_attention` as `down_block_types`."
                f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `attention_head_dim` as `down_block_types`."
                f" `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # input
        # "same"-style padding for odd kernel sizes
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one entry per down block.
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        # The up path walks the per-block settings in the opposite order to the down path.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.up_blocks.append(up_block)

        # out
        # Normalization and activation are skipped entirely when `norm_num_groups` is None.
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = nn.SiLU()
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )
Patrick von Platen's avatar
Patrick von Platen committed
325

326
327
328
329
330
331
332
    @property
    def attn_processors(self) -> Dict[str, AttnProcessor]:
        r"""
        Collect the attention processor of every sub-module that exposes a `set_processor` method.

        Returns:
            `dict` of attention processors: maps `"<dotted module path>.processor"` to the module's `processor`
            attribute, in depth-first (parent before children) order.
        """
        found = {}

        # Explicit stack instead of recursion; children are pushed in reverse so that they are
        # popped (and therefore visited) in their registration order.
        pending = list(reversed(list(self.named_children())))
        while pending:
            path, node = pending.pop()

            if hasattr(node, "set_processor"):
                found[f"{path}.processor"] = node.processor

            descendants = [(f"{path}.{child_name}", child) for child_name, child in node.named_children()]
            pending.extend(reversed(descendants))

        return found

    def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
        r"""
        Install attention processor(s) on every sub-module that exposes a `set_processor` method.

        Parameters:
            processor (`dict` of `AttnProcessor` or `AttnProcessor`):
                Either a single processor instance, which is set on **all** matching layers, or a dictionary
                mapping `"<dotted module path>.processor"` (the keys returned by `attn_processors`) to the
                processor for that layer. The dictionary form is strongly recommended when setting trainable
                attention processors. The dictionary is only read; it is not modified by this call.

        Raises:
            ValueError: If a dictionary is passed whose number of entries differs from the number of attention
                layers.
            KeyError: If a dictionary is passed that lacks the key for one of the attention layers.
        """
        count = len(self.attn_processors)

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Plain lookup rather than `pop`: the caller's dictionary must not be emptied
                    # as a side effect of applying it.
                    module.set_processor(processor[f"{name}.processor"])

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
379

380
    def set_attention_slice(self, slice_size):
        r"""
        Configure sliced attention on every sub-module that exposes `set_attention_slice`.

        With slicing enabled, the attention module splits its input tensor into slices and computes attention in
        several steps, trading a small amount of speed for lower memory use.

        Args:
            slice_size (`str` or `int` or `list(int)`):
                `"auto"` uses half of each layer's `sliceable_head_dim`, so attention is computed in two steps.
                `"max"` uses a slice size of 1 for every layer, saving the most memory. A single number is applied
                to every layer; a list must contain one entry per sliceable layer. Each entry must not exceed the
                corresponding layer's `sliceable_head_dim`.

        Raises:
            ValueError: If a list of the wrong length is given, or an entry exceeds its layer's head dimension.
        """
        sliceable_head_dims = []

        def gather_dims(node: torch.nn.Module):
            # Depth-first, parent before children: the same order in which sizes are handed out below.
            if hasattr(node, "set_attention_slice"):
                sliceable_head_dims.append(node.sliceable_head_dim)

            for sub in node.children():
                gather_dims(sub)

        for top_level in self.children():
            gather_dims(top_level)

        layer_count = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # smallest possible slice for every layer
            slice_size = [1] * layer_count

        # A scalar (or None) applies uniformly to every sliceable layer.
        if not isinstance(slice_size, list):
            slice_size = [slice_size] * layer_count

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Hand the sizes out in traversal order; every module exposing `set_attention_slice`
        # consumes exactly one entry.
        remaining = iter(slice_size)

        def assign_slices(node: torch.nn.Module):
            if hasattr(node, "set_attention_slice"):
                node.set_attention_slice(next(remaining))

            for sub in node.children():
                assign_slices(sub)

        for top_level in self.children():
            assign_slices(top_level)
444

445
446
447
448
    def _set_gradient_checkpointing(self, module, value=False):
        """Set `module.gradient_checkpointing` to `value` when `module` is one of the 2D down/up block types."""
        checkpointable_types = (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
        if not isinstance(module, checkpointable_types):
            return
        module.gradient_checkpointing = value

Patrick von Platen's avatar
Patrick von Platen committed
449
450
451
452
453
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Run the UNet on a noisy sample, conditioned on a timestep and encoder hidden states.

        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels (`torch.Tensor`, *optional*):
                Class conditioning input. Required whenever the model was built with a class embedding. When
                `class_embed_type` is `"timestep"` it is first passed through the timestep projection.
            timestep_cond (`torch.Tensor`, *optional*):
                Forwarded as the second argument to the timestep embedding layer.
            attention_mask (`torch.Tensor`, *optional*):
                Converted to an additive bias via `(1 - mask) * -10000.0` and given an extra dimension at index 1
                before being passed to the attention blocks (presumably 1 marks positions to keep and 0 positions
                to mask out; confirm against the attention implementation).
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: If the model has a class embedding but `class_labels` is `None`.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
                # Same cast as for `t_emb` above: the projection output must match the model dtype
                # before it reaches the class embedding layer.
                class_labels = class_labels.to(dtype=self.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        # Residuals are accumulated in order and consumed from the end by the up blocks.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Each up block takes one residual per resnet, from the tail of the stack.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        # `conv_norm_out` / `conv_act` are None when the model was built with `norm_num_groups=None`.
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)