# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UNetMidBlock2DSimpleCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes a noisy sample, a conditional state, and a timestep
    and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
            summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: str = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

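        # per-block settings may be passed as a single value; broadcast them so
        # there is one entry per down block (per-block tuples pass through as-is)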
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
        else:
            raise ValueError(f"unknown mid_block_type: {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

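            # channel bookkeeping: `prev_output_channel` is what the previous up block
            # produced; `input_channel` matches the skip connections from the down path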
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    @property
    def attn_processors(self) -> Dict[str, AttnProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

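    # A small usage sketch (illustrative only, assuming a constructed `unet`
    # instance): enumerate every attention processor together with its path, e.g.
    #
    #     for name, proc in unet.attn_processors.items():
    #         print(name, type(proc).__name__)
    #
    # Keys look like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor".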
    def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
        r"""
        Parameters:
            processor (`dict` of `AttnProcessor` or `AttnProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                of **all** `CrossAttention` layers. If `processor` is a dict, its keys need to define the path to the
                corresponding cross attention processor. This is strongly recommended when setting trainable attention
                processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

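    # Usage sketch (illustrative; `MyProcessor` stands for a hypothetical class
    # implementing the attention-processor interface from `diffusers.models.cross_attention`):
    #
    #     unet.set_attn_processor(MyProcessor())  # one processor for every layer
    #
    #     # or per layer, keyed by the paths from `unet.attn_processors`:
    #     unet.set_attn_processor({name: MyProcessor() for name in unet.attn_processors})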
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

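    # Usage sketch (illustrative): slicing trades a small amount of speed for memory.
    #
    #     unet.set_attention_slice("auto")  # half of each sliceable head dim
    #     unet.set_attention_slice("max")   # slice size 1 everywhere, most memory saved
    #     unet.set_attention_slice(2)       # the same fixed slice size for every layer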
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
            module.gradient_checkpointing = value

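    # Note: this hook is not meant to be called directly. `ModelMixin.enable_gradient_checkpointing()`
    # applies it recursively over all submodules with `value=True` (and
    # `disable_gradient_checkpointing()` with `value=False`), which is why only the
    # block classes listed above carry the flag.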
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default, samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
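        # the incoming mask marks valid positions with 1; convert it to an additive
        # bias where masked positions receive a large negative value (a float16-safe
        # stand-in for -inf), and add a broadcastable head dimension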
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
        )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

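            # pop this block's skip connections (one per resnet) off the stack
            # that was accumulated during the down pass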
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )
        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)
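

# A minimal smoke-test sketch (illustrative only, not part of the library API):
# build a tiny UNet and run a single conditioned denoising step. The channel
# counts and the 77-token conditioning shape below are arbitrary example values.
# Because this module uses relative imports, run it as part of the package,
# e.g. `python -m diffusers.models.unet_2d_condition`.
if __name__ == "__main__":
    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        layers_per_block=1,
        cross_attention_dim=32,
        attention_head_dim=8,
        norm_num_groups=32,  # must evenly divide every entry of block_out_channels
    )
    noisy_sample = torch.randn(1, 4, 32, 32)
    encoder_hidden_states = torch.randn(1, 77, 32)  # e.g. a text-encoder output
    output = unet(noisy_sample, timestep=10, encoder_hidden_states=encoder_hidden_states)
    assert output.sample.shape == noisy_sample.shape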