unet_2d_condition.py 15.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
from dataclasses import dataclass
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
16
17
18

import torch
import torch.nn as nn
19
import torch.utils.checkpoint
Patrick von Platen's avatar
Patrick von Platen committed
20
21
22

from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
23
from ..utils import BaseOutput, logging
Patrick von Platen's avatar
Patrick von Platen committed
24
from .embeddings import TimestepEmbedding, Timesteps
25
from .unet_2d_blocks import (
26
27
28
29
30
31
32
33
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    UpBlock2D,
    get_down_block,
    get_up_block,
)
Patrick von Platen's avatar
Patrick von Platen committed
34
35


36
37
38
# Module-level logger for this file; `UNet2DConditionModel.forward` logs through it when it has to
# forward explicit upsample sizes for inputs that are not a multiple of the overall upsampling factor.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


39
40
41
42
43
44
45
46
47
48
49
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Container for the result of [`UNet2DConditionModel.forward`] when it is called with `return_dict=True`.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Final-layer output of the model, i.e. the hidden states after conditioning on the
            `encoder_hidden_states` input.
    """

    # Single payload field; plain-tuple callers receive this same tensor as element 0.
    sample: torch.FloatTensor


class UNet2DConditionModel(ModelMixin, ConfigMixin):
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all the models (such as downloading or saving, etc.)

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. Must have the same length as `up_block_types` and
            `block_out_channels`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per down block (up blocks are built with `layers_per_block + 1`).
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function passed to the down, mid and up blocks. The final output activation is always
            `nn.SiLU`, independent of this setting.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        dual_cross_attention (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to every down, mid and up block.

    Raises:
        ValueError: if `down_block_types`, `up_block_types` and `block_out_channels` do not all have the same length.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: int = 8,
        dual_cross_attention: bool = False,
    ):
        super().__init__()

        # The down and up paths are built stage by stage from the same `block_out_channels` (the up path reads it
        # reversed), and "final block" is decided by comparing against `len(block_out_channels)`. If the three
        # tuples disagree in length, the channel and skip-connection bookkeeping below does not line up, so fail
        # early with a clear message instead of an opaque index/shape error later.
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                "Must provide the same number of `down_block_types` as `up_block_types`. "
                f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                "Must provide the same number of `block_out_channels` as `down_block_types`. "
                f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                # the deepest stage keeps its resolution; every other stage downsamples
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim,
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim,
            resnet_groups=norm_num_groups,
            dual_cross_attention=dual_cross_attention,
        )

        # count how many layers upsample the images; `forward` uses this to decide whether input sizes
        # need explicit upsample sizes
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # channels of the skip connections coming from the matching (one stage shallower) down block
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim,
                dual_cross_attention=dual_cross_attention,
            )
            self.up_blocks.append(up_block)

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attention_slice(self, slice_size):
        """
        Forward an attention slice size to every block that has attention layers.

        Args:
            slice_size (`int` or `None`):
                Passed to `set_attention_slice` of each down/up block that has non-`None` `attentions`, and of the
                mid block. An integer must be positive, no larger than `config.attention_head_dim`, and a divisor of
                it. `None` is forwarded unchanged (presumably disabling slicing — depends on the blocks).

        Raises:
            ValueError: if an integer `slice_size` violates the constraints above.
        """
        num_heads = self.config.attention_head_dim
        if slice_size is not None:
            # Reject 0 (would raise ZeroDivisionError in the modulo below) and negative sizes.
            if slice_size <= 0:
                raise ValueError(f"slice_size {slice_size} has to be a positive integer")
            # The upper bound is checked before divisibility: any slice_size larger than the head count also
            # fails the divisibility test, so checking it second made this error unreachable.
            if slice_size > num_heads:
                raise ValueError(
                    f"Chunk_size {slice_size} has to be smaller or equal to "
                    f"the number of heads used in cross_attention {self.config.attention_head_dim}"
                )
            if num_heads % slice_size != 0:
                raise ValueError(
                    f"Make sure slice_size {slice_size} is a divisor of "
                    f"the number of heads used in cross_attention {self.config.attention_head_dim}"
                )

        for block in self.down_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_attention_slice(slice_size)

        self.mid_block.set_attention_slice(slice_size)

        for block in self.up_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_attention_slice(slice_size)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """
        Forward the xformers memory-efficient-attention flag to the mid block and to every down/up block that
        has non-`None` `attentions`.
        """
        for block in self.down_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

        self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

        for block in self.up_blocks:
            if hasattr(block, "attentions") and block.attentions is not None:
                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `module` if it is one of the down/up block types used by this model."""
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps. Python floats are kept as
                floating-point values; Python ints become a `torch.long` tensor. A scalar is broadcast to the batch.
            encoder_hidden_states (`torch.FloatTensor`):
                (batch_size, sequence_length, hidden_size) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # Keep fractional timesteps fractional: casting a Python float to `torch.long` would truncate it
            # (e.g. 999.5 -> 999) before it reaches `self.time_proj`.
            dtype = torch.float32 if isinstance(timesteps, float) else torch.long
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif timesteps.dim() == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        # every intermediate output is kept as a skip connection for the up path
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # pop as many skip connections (from the end) as this block has resnets
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample)