# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxCrossAttnUpBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
    FlaxUpBlock2D,
)


35
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
    """
    The output of [`FlaxUNet2DConditionModel`] when it is called with `return_dict=True`.

    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model,
            transposed back to the channels-first layout of the model input.
    """

    sample: jnp.ndarray


@flax_register_to_config
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*, defaults to 32):
            The height and width of the square dummy sample that `init_weights` uses to initialize the parameters.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. `"CrossAttnDownBlock2D"` builds a [`FlaxCrossAttnDownBlock2D`];
            any other value builds a [`FlaxDownBlock2D`].
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use. `"CrossAttnUpBlock2D"` builds a [`FlaxCrossAttnUpBlock2D`]; any
            other value builds a [`FlaxUpBlock2D`].
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Forwarded to the cross-attention down and up blocks. A single `bool` applies to every block. A tuple is
            indexed per down block and used in reverse order for the up blocks.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per down block. Up blocks use `layers_per_block + 1` layers.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            Despite its name, this value is currently passed to the blocks as the *number* of attention heads. This
            is a historical naming issue, see
            https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131.
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads. Not supported yet: any value other than `None` raises a `ValueError`
            when the module is set up.
        cross_attention_dim (`int`, *optional*, defaults to 1280):
            The dimension of the cross attention features, i.e. the last axis of `encoder_hidden_states`.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded to the cross-attention down, mid and up blocks.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` passed to the convolutions, the time embedding, and the down/mid/up blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0):
            The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
    """

    sample_size: int = 32
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False

    # `jax.Array` is used instead of `jax.random.KeyArray`: the latter was removed from recent JAX releases,
    # and because annotations are evaluated at definition time, referencing it would break the class definition.
    def init_weights(self, rng: jax.Array) -> FrozenDict:
        """
        Initialize the model parameters by tracing the model on dummy inputs.

        Args:
            rng (`jax.Array`):
                PRNG key. It is split into one key for parameter initialization and one for dropout.

        Returns:
            `FrozenDict`: the `"params"` collection returned by `Module.init`.
        """
        # Dummy inputs: one square channels-first sample, one timestep, and a length-1 encoder sequence.
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def setup(self):
        """Build the input conv, time embedding, down/mid/up blocks and output head from the config fields."""
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        if self.num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        # `self.freq_shift` is the same config field that `self.config.freq_shift` exposes; reading the dataclass
        # field keeps this consistent with `self.flip_sin_to_cos` right next to it.
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        # Broadcast per-model scalars to one value per down block.
        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # down
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # mid
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
        )

        # up: per-block settings are consumed in reverse order, mirroring the down path.
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            if up_block_type == "CrossAttnUpBlock2D":
                up_block = FlaxCrossAttnUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    num_attention_heads=reversed_num_attention_heads[i],
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:
                up_block = FlaxUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    dtype=self.dtype,
                )

            up_blocks.append(up_block)
        self.up_blocks = up_blocks

        # out
        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        down_block_additional_residuals=None,
        mid_block_additional_residual=None,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`):
                Timesteps. Accepts a scalar (Python number or 0-d array) or a 1-D array-like (JAX/NumPy array or
                list). A scalar is converted to a length-1 `float32` array.
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            down_block_additional_residuals (`Tuple[jnp.ndarray]`, *optional*):
                Residuals added element-wise to the down-block skip connections, in order. They are added in the
                model's internal channels-last layout `(batch, height, width, channels)`. Pairing uses `zip`, so
                surplus entries on either side are ignored.
            mid_block_additional_residual (`jnp.ndarray`, *optional*):
                Residual added to the mid-block output, also in channels-last layout.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
                plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
        # Normalize to a 1-D array. `jnp.asarray` is a no-op for JAX arrays and tracers, and it keeps NumPy
        # arrays and lists 1-D instead of wrapping them into shape (1, B). Scalars are cast to float32 so that
        # fractional timesteps are not truncated, matching the existing 0-d array path.
        timesteps = jnp.asarray(timesteps)
        if timesteps.ndim == 0:
            timesteps = jnp.expand_dims(timesteps.astype(dtype=jnp.float32), 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process: NCHW -> NHWC, the layout flax convolutions expect.
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            # `zip` stops at the shorter sequence, exactly like the pairing loop it replaces.
            down_block_res_samples = tuple(
                res_sample + additional_residual
                for res_sample, additional_residual in zip(down_block_res_samples, down_block_additional_residuals)
            )

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        if mid_block_additional_residual is not None:
            sample = sample + mid_block_additional_residual

        # 5. up: each up block consumes the last `layers_per_block + 1` skip connections.
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
            down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
            if isinstance(up_block, FlaxCrossAttnUpBlock2D):
                sample = up_block(
                    sample,
                    temb=t_emb,
                    encoder_hidden_states=encoder_hidden_states,
                    res_hidden_states_tuple=res_samples,
                    deterministic=not train,
                )
            else:
                sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)

        # 6. post-process, then NHWC -> NCHW to match the input layout.
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        sample = jnp.transpose(sample, (0, 3, 1, 2))

        if not return_dict:
            return (sample,)

        return FlaxUNet2DConditionOutput(sample=sample)