# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp

from .attention_flax import FlaxTransformer2DModel
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D


class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    Cross Attention 2D Downsizing block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    The block chains `num_layers` (resnet, transformer) pairs and, when `add_downsample` is set, one final
    downsampler.

    Parameters:
        in_channels (:obj:`int`):
            Input channels of the first resnet
        out_channels (:obj:`int`):
            Output channels of every resnet; also the channel count given to the transformers and the downsampler
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, handed to each resnet as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (resnet, transformer) pairs
        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
            Passed to each transformer as `n_heads`; the per-head width is
            `out_channels // attn_num_head_channels`
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append a downsampling layer after the last pair
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` forwarded to every sub-module
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Only the first resnet consumes `in_channels`; the following ones chain on `out_channels`.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.out_channels if layer_idx else self.in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        # One transformer per resnet, all operating on `out_channels`.
        self.attentions = [
            FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.attn_num_head_channels,
                d_head=self.out_channels // self.attn_num_head_channels,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Run every (resnet, transformer) pair and the optional downsampler.

        Returns a pair `(hidden_states, output_states)`, where `output_states` is a tuple holding the activation
        after each pair followed, if `add_downsample` is set, by the downsampled activation.
        """
        collected = []

        for res_layer, attn_layer in zip(self.resnets, self.attentions):
            hidden_states = res_layer(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn_layer(hidden_states, encoder_hidden_states, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)


class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D downsizing block

    The block chains `num_layers` resnets and, when `add_downsample` is set, one final downsampler. It contains no
    attention layers.

    Parameters:
        in_channels (:obj:`int`):
            Input channels of the first resnet
        out_channels (:obj:`int`):
            Output channels of every resnet; also the channel count given to the downsampler
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, handed to each resnet as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append a downsampling layer after the last resnet
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` forwarded to every sub-module
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Only the first resnet consumes `in_channels`; the following ones chain on `out_channels`.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.out_channels if layer_idx else self.in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, deterministic=True):
        """
        Run every resnet and the optional downsampler.

        Returns a pair `(hidden_states, output_states)`, where `output_states` is a tuple holding the activation
        after each resnet followed, if `add_downsample` is set, by the downsampled activation.
        """
        collected = []

        for res_layer in self.resnets:
            hidden_states = res_layer(hidden_states, temb, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)


class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    The block chains `num_layers` (resnet, transformer) pairs and, when `add_upsample` is set, one final upsampler.
    Before each resnet the running activation is concatenated with one skip tensor taken from the end of
    `res_hidden_states_tuple`.

    Parameters:
        in_channels (:obj:`int`):
            Channel count budgeted for the skip tensor consumed by the last resnet
        out_channels (:obj:`int`):
            Output channels of every resnet; also the channel count budgeted for the skip tensors of all resnets
            but the last, and the channel count given to the transformers and the upsampler
        prev_output_channel (:obj:`int`):
            Channel count budgeted for the activation entering the first resnet
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, handed to each resnet as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (resnet, transformer) pairs
        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
            Passed to each transformer as `n_heads`; the per-head width is
            `out_channels // attn_num_head_channels`
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append an upsampling layer after the last pair
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` forwarded to every sub-module
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    attn_num_head_channels: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        final_idx = self.num_layers - 1
        res_layers = []
        attn_layers = []

        for layer_idx in range(self.num_layers):
            # Main path: `prev_output_channel` for the first resnet, `out_channels` afterwards.
            main_width = self.out_channels if layer_idx else self.prev_output_channel
            # Skip path: `in_channels` for the last resnet, `out_channels` before it.
            skip_width = self.in_channels if layer_idx == final_idx else self.out_channels

            res_layers.append(
                FlaxResnetBlock2D(
                    in_channels=main_width + skip_width,
                    out_channels=self.out_channels,
                    dropout_prob=self.dropout,
                    dtype=self.dtype,
                )
            )
            attn_layers.append(
                FlaxTransformer2DModel(
                    in_channels=self.out_channels,
                    n_heads=self.attn_num_head_channels,
                    d_head=self.out_channels // self.attn_num_head_channels,
                    depth=1,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=self.only_cross_attention,
                    dtype=self.dtype,
                )
            )

        self.resnets = res_layers
        self.attentions = attn_layers

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        """
        Merge skip tensors into the activation, run every (resnet, transformer) pair, then optionally upsample.

        Skip tensors are consumed from the end of `res_hidden_states_tuple`, one per pair, and joined to the
        activation along the last axis. Returns the final activation.
        """
        layer_pairs = zip(self.resnets, self.attentions)
        for depth, (res_layer, attn_layer) in enumerate(layer_pairs, start=1):
            # The depth-th pair consumes the depth-th entry counted from the end of the tuple.
            skip = res_hidden_states_tuple[-depth]
            hidden_states = jnp.concatenate((hidden_states, skip), axis=-1)

            hidden_states = res_layer(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn_layer(hidden_states, encoder_hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D upsampling block

    The block chains `num_layers` resnets and, when `add_upsample` is set, one final upsampler. Before each resnet
    the running activation is concatenated with one skip tensor taken from the end of `res_hidden_states_tuple`.
    It contains no attention layers.

    Parameters:
        in_channels (:obj:`int`):
            Channel count budgeted for the skip tensor consumed by the last resnet
        out_channels (:obj:`int`):
            Output channels of every resnet; also the channel count budgeted for the skip tensors of all resnets
            but the last, and the channel count given to the upsampler
        prev_output_channel (:obj:`int`):
            Channel count budgeted for the activation entering the first resnet
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, handed to each resnet as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet layers
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append an upsampling layer after the last resnet
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` forwarded to every sub-module
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        final_idx = self.num_layers - 1
        res_layers = []

        for layer_idx in range(self.num_layers):
            # Main path: `prev_output_channel` for the first resnet, `out_channels` afterwards.
            main_width = self.out_channels if layer_idx else self.prev_output_channel
            # Skip path: `in_channels` for the last resnet, `out_channels` before it.
            skip_width = self.in_channels if layer_idx == final_idx else self.out_channels

            res_layers.append(
                FlaxResnetBlock2D(
                    in_channels=main_width + skip_width,
                    out_channels=self.out_channels,
                    dropout_prob=self.dropout,
                    dtype=self.dtype,
                )
            )

        self.resnets = res_layers

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        """
        Merge skip tensors into the activation, run every resnet, then optionally upsample.

        Skip tensors are consumed from the end of `res_hidden_states_tuple`, one per resnet, and joined to the
        activation along the last axis. Returns the final activation.
        """
        for depth, res_layer in enumerate(self.resnets, start=1):
            # The depth-th resnet consumes the depth-th entry counted from the end of the tuple.
            skip = res_hidden_states_tuple[-depth]
            hidden_states = jnp.concatenate((hidden_states, skip), axis=-1)

            hidden_states = res_layer(hidden_states, temb, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104

    The block is one leading resnet followed by `num_layers` (transformer, resnet) pairs, so it holds
    `num_layers + 1` resnets and `num_layers` transformers. Every sub-module is built with `in_channels` as both
    its input and output channel count.

    Parameters:
        in_channels (:obj:`int`):
            Input channels; also used as the output channels of every resnet
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, handed to each resnet as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (transformer, resnet) pairs that follow the leading resnet
        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
            Passed to each transformer as `n_heads`; the per-head width is
            `in_channels // attn_num_head_channels`
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            `dtype` forwarded to every sub-module
    """

    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    attn_num_head_channels: int = 1
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.attn_num_head_channels,
                d_head=self.in_channels // self.attn_num_head_channels,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Run the leading resnet, then every (transformer, resnet) pair, and return the final activation.

        `deterministic` is forwarded to every resnet and transformer call.
        """
        # Forward `deterministic` here too: the leading resnet previously fell back to its own default and
        # ignored the caller's flag, unlike every other resnet call in this file.
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states