# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput, logging
from .embeddings import TimestepEmbedding, Timesteps
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Output container returned by [`UNet2DConditionModel.forward`] when `return_dict=True`.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet2DConditionModel(ModelMixin, ConfigMixin):
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`):
            Whether to center the input sample, i.e. map it with `2 * sample - 1.0` before the first convolution.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Forwarded to every down block and (in reversed order) to every up block. A single `bool` is broadcast to
            all blocks.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per down block. Up blocks use `layers_per_block + 1` layers.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function used inside the blocks. The final output activation is always SiLU.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads. A single `int` is broadcast to every down block. A tuple gives one
            value per down block; the up blocks use it in reversed order and the mid block uses its last entry.
        dual_cross_attention (`bool`, *optional*, defaults to `False`):
            Forwarded to the down blocks, the up blocks and a `UNetMidBlock2DCrossAttn` mid block.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded to the down blocks, the up blocks and a `UNetMidBlock2DCrossAttn` mid block.
        class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
            summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. With `"timestep"` the
            labels go through the sinusoidal time projection and a `TimestepEmbedding`; with `"identity"` they are
            added to the time embedding unchanged.
        num_class_embeds (`int`, *optional*, defaults to None):
            When set and `class_embed_type` is `None`, a learnable `nn.Embedding(num_class_embeds, time_embed_dim)`
            is used as the class embedding (`time_embed_dim = 4 * block_out_channels[0]`).
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Forwarded to the down blocks, the up blocks and a `UNetMidBlock2DCrossAttn` mid block.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: str = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
    ):
        super().__init__()

        self.sample_size = sample_size
        # width of the time (and class) embedding fed to every block
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding (summed with the time embedding in `forward`)
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # broadcast scalar per-block settings to one value per down block
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images (used in `forward` to detect odd input sizes)
        self.num_upsamplers = 0

        # up: per-block settings are consumed in reverse order of the down path
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        reversed_only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # `in_channels` comes from the next entry of the reversed channel list, clamped to its last index
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            add_upsample = not is_final_block
            if add_upsample:
                self.num_upsamplers += 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=reversed_only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.up_blocks.append(up_block)

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)` or `tuple(int)`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`. If a list or tuple is provided, it must hold one entry per
                sliceable attention layer, in module traversal order.

        Raises:
            ValueError: if a list/tuple has the wrong length, or a slice size exceeds its layer's sliceable head dim.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        # a sequence gives one size per layer (copied so the caller's object is never touched);
        # anything else is broadcast to every sliceable layer
        if isinstance(slice_size, (list, tuple)):
            slice_size = list(slice_size)
        else:
            slice_size = num_slicable_layers * [slice_size]

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # reversed so that `pop()` hands out sizes in the same traversal order they were collected in
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Labels for the class embedding. Required when the model has a class embedding (`num_class_embeds` or
                `class_embed_type` was set); a `ValueError` is raised if they are missing.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                Mask where 1 keeps a position and 0 masks it out. It is turned into an additive bias (`-10000.0` on
                masked positions), given an extra dimension at index 1, and forwarded to the cross-attention blocks.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask: 1 -> 0 bias (keep), 0 -> -10000 bias (masked)
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
                # Same reason as for `t_emb` above: `time_proj` yields f32, while the `TimestepEmbedding`
                # used as class embedding may run in fp16, so cast before feeding it in.
                class_labels = class_labels.to(dtype=self.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down: collect every intermediate output for the skip connections of the up path
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if getattr(downsample_block, "has_cross_attention", False):
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # each up block consumes as many skip tensors as it has resnets, taken from the end of the stack
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if getattr(upsample_block, "has_cross_attention", False):
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)